diff --git a/tools/bazel.rc b/.bazelrc similarity index 88% rename from tools/bazel.rc rename to .bazelrc index 1fdf51f53e29c7111cf89c016400b710051cf9c6..17285afdb381018d0054e771475327b1f7ed9866 100644 --- a/tools/bazel.rc +++ b/.bazelrc @@ -25,12 +25,14 @@ build --define framework_shared_object=true # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true +build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl -c opt # This config option is used to enable MKL-DNN open source library only, # without depending on MKL binary version. build:mkl_open_source_only --define=build_with_mkl_dnn_only=true build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0 build:download_clang --crosstool_top=@local_config_download_clang//:toolchain build:download_clang --define=using_clang=true @@ -76,10 +78,9 @@ build:nonccl --define=no_nccl_support=true build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true -build --define=grpc_no_ares=true build --spawn_strategy=standalone -build --genrule_strategy=standalone +build --strategy=Genrule=standalone build -c opt # Other build flags. @@ -89,7 +90,21 @@ build --define=grpc_no_ares=true build:dynamic_kernels --define=dynamic_loaded_kernels=true build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS +# Build TF with C++ 17 features. 
+build:c++17 --cxxopt=-std=c++1z +build:c++17 --cxxopt=-stdlib=libc++ +build:c++1z --cxxopt=-std=c++1z +build:c++1z --cxxopt=-stdlib=libc++ + # Default paths for TF_SYSTEM_LIBS build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include + +# Default options should come above this line + +# Options from ./configure +try-import %workspace%/.tf_configure.bazelrc + +# Put user-specific options in .bazelrc.user +try-import %workspace%/.bazelrc.user diff --git a/.gitignore b/.gitignore index 90324058600bee46af56e49028977971848a80de..e1d352c238a1b2d4febe0f5d4a30cfa0c942f7e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ .DS_Store .ipynb_checkpoints node_modules -/.bazelrc +/.bazelrc.user /.tf_configure.bazelrc /bazel-* /bazel_pip diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4a296f265f7b9521c46d350cec26ff199f43eb6c..b978f89f9e1d79dd4f7481711a59c2b94e8bf01b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -150,41 +150,45 @@ may exist in your changes. There are two ways to run TensorFlow unit tests. -1. Using tools and libraries installed directly on your system. +1. Using tools and libraries installed directly on your system. - Refer to the - [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and - [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) - for the required packages. Alternatively, use the said - [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., - `tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu` - for development to avoid installing the packages directly on your system. 
+ Refer to the + [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) + and + [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) + for the required packages. Alternatively, use the said + [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., + `tensorflow/tensorflow:nightly-devel` and + `tensorflow/tensorflow:nightly-devel-gpu` for development to avoid + installing the packages directly on your system (in which case remember to + change directory from `/root` to `/tensorflow` once you get into the running + container so `bazel` can find the `tensorflow` workspace). - Once you have the packages installed, you can run a specific unit test in - bazel by doing as follows: + Once you have the packages installed, you can run a specific unit test in + bazel by doing as follows: - If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add - the `cuda` option flag + If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add + the `cuda` option flag - ```bash - export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + ```bash + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - export flags="--config=opt --config=cuda -k" - ``` + export flags="--config=opt --config=cuda -k" + ``` - For example, to run all tests under tensorflow/python, do: + For example, to run all tests under tensorflow/python, do: - ```bash - bazel test ${flags} //tensorflow/python/... - ``` + ```bash + bazel test ${flags} //tensorflow/python/... + ``` -2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. +2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. 
- ```bash - # Install Docker first, then this will build and run cpu tests - tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... - ``` - - See - [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details. + ```bash + # Install Docker first, then this will build and run cpu tests + tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... + ``` + See + [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) + for details. diff --git a/README.md b/README.md index 044174947a094d43a51f7140dd40ec0f17801d40..96a8ecf4f693d5634da63f4ecc6f4e9c35751f5b 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,8 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. -TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift. +TensorFlow provides stable Python and C APIs, as well as APIs without a backwards +compatibility guarantee for C++, Go, Java, JavaScript and Swift. Keep up to date with release announcements and security updates by subscribing to @@ -57,21 +58,24 @@ Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean environment to install the nightly TensorFlow build. We support CPU and GPU packages on Linux, Mac, and Windows. - #### *Try your first TensorFlow program* + ```shell $ python ``` + ```python >>> import tensorflow as tf >>> tf.enable_eager_execution() ->>> tf.add(1, 2) +>>> tf.add(1, 2).numpy() 3 >>> hello = tf.constant('Hello, TensorFlow!') >>> hello.numpy() 'Hello, TensorFlow!' ``` -Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). 
+ +Learn more examples about how to do specific tasks in TensorFlow at the +[tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines @@ -113,11 +117,12 @@ The TensorFlow project strives to abide by generally accepted best practices in Build Type | Status | Artifacts ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA -**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA -**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) -**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) +**Linux ppc64le CPU** Stable Release | [![Build 
Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) +**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) -**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl) +**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)
[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)
[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)
[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl) ## For more information diff --git a/RELEASE.md b/RELEASE.md index b13b071bd6cf4d3a260c8e248a67d23e1a688498..0a56e6909870e398c9d6349576cd2f8e6734f072 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,6 +7,8 @@ Serving. * Keras models now support evaluating with a `tf.data.Dataset`. * TensorFlow binaries are built with XLA support linked in by default. +* Ignite Dataset added to contrib/ignite that allows to work with Apache + Ignite. ## Bug Fixes and Other Changes @@ -280,50 +282,76 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A ## Bug Fixes and Other Changes -* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. -* Layered variable names have changed in the following conditions: - * Using `tf.keras.layers` with custom variable scopes. - * Using `tf.layers` in a subclassed `tf.keras.Model` class. See - [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details -* `tf.data`: - * `Dataset.from_generator()` now accepts an `args` list, in order to create nested generators. - * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. - * `tf.contrib.data.sample_from_datasets()` and `tf.contrib.data.choose_from_datasets()` make it easier to sample or deterministically choose elements from multiple datasets. - * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings, and two infrequently used arguments removed. - * (C++) `DatasetBase::DebugString()` is now `const`. - * (C++) `DatasetBase::MakeIterator()` has been renamed to `DatasetBase::MakeIteratorInternal()`. - * (C++) `IteratorBase::Initialize()` method was added to support raising errors during iterator construction. -* Eager Execution: - * Added the ability to pause recording operations for gradient computation via `tf.GradientTape.stop_recording`. 
- * Updated documentation, introductory notebooks. -* `tf.keras`: - * Move Keras code out of _impl folder and remove API files. - * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. - * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. -* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). -* `tf.contrib`: - * `tf.contrib.framework.zero_initializer` supports ResourceVariable. - * Adding "constrained_optimization" to tensorflow/contrib. -* Other: - * Add GCS Configuration Ops. - * Changing signature of `MakeIterator` to enable propagating error status. - * KL divergence for two Dirichlet distributions. - * More consistent GcsFileSystem behavior for certain reads past EOF. - * Update benchmark for tf.scan to match ranges across eager and graph modes. - * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. - * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). - * Benchmark for tf.scan in graph and eager modes. - * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. - * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. - * Support indicator column in boosted trees. - * Prevent `tf.gradients()` from backpropagating through integer tensors. - * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. - * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. 
- * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. - * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. - * Allow LinearOperator to broadcast. - * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. - +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See + [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) + for more details +* `tf.data`: + * `Dataset.from_generator()` now accepts an `args` list, in order to + create nested generators. + * `Dataset.list_files()` now produces deterministic results when + `shuffle=False` or a `seed` is passed. + * `tf.contrib.data.sample_from_datasets()` and + `tf.contrib.data.choose_from_datasets()` make it easier to sample or + deterministically choose elements from multiple datasets. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted + strings, and two infrequently used arguments removed. + * (C++) `DatasetBase::DebugString()` is now `const`. + * (C++) `DatasetBase::MakeIterator()` has been renamed to + `DatasetBase::MakeIteratorInternal()`. + * (C++) `IteratorBase::Initialize()` method was added to support raising + errors during iterator construction. +* Eager Execution: + * Added the ability to pause recording operations for gradient computation + via `tf.GradientTape.stop_recording`. + * Updated documentation, introductory notebooks. +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. 
+ * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval + methods. +* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard + Debugger Plugin could not handle total source file size exceeding gRPC + message size limit (4 MB). +* `tf.contrib`: + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. + * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph + modes. + * Fixed bug in `tf.reduce_prod` gradient for complex dtypes. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), + which would previously raise an error. This will correspond to an + attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only + be accessed indirectly (e.g. through getattr and setattr). To set this + up the user will first need to explicitly add the variable to the hparam + object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce + RPC calls for looking up the embeddings when there are repeated ids in + the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now support + arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based + checkpoints. 
+ * Added LinearOperatorKronecker, a dense-free implementation of the + Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files + with the same basename and the same contents. Note that this may result + in new asset files included in SavedModels in cases where assets with + the same name but different contents were previously overwriting each + other. ## Thanks to our Contributors @@ -821,7 +849,7 @@ answered questions, and were part of inspiring discussions. * Remove `tf.contrib.data.Iterator.from_dataset()` method. Use `Dataset.make_initializable_iterator()` instead. * Remove seldom used and unnecessary `tf.contrib.data.Iterator.dispose_op()`. -* Reorder some TFGAN loss functions in a non-backwards compatible way. +* Reorder some TF-GAN loss functions in a non-backwards compatible way. ## Known Issues * In Python 3, `Dataset.from_generator()` does not support Unicode strings. diff --git a/WORKSPACE b/WORKSPACE index 7cc08e0164a202581ad7ebbe107a9e19410e70e4..9f07b9fd47136d058cc4039ed6948db539485039 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,14 +1,14 @@ workspace(name = "org_tensorflow") -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") http_archive( name = "io_bazel_rules_closure", - sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", - strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", + sha256 = "43c9b882fa921923bcba764453f4058d102bece35a37c9f6383c713004aacff1", + strip_prefix = "rules_closure-9889e2348259a5aad7e805547c1a0cf311cfcd91", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 + 
"https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", # 2018-12-21 ], ) @@ -16,38 +16,52 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() -http_archive( - name = "base_images_docker", - sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", - strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", - urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], -) +load("//third_party/toolchains/preconfig/generate:archives.bzl", + "bazel_toolchains_archive") -http_archive( - name = "bazel_toolchains", - sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", - strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", - urls = [ - "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", - ], +bazel_toolchains_archive() + +load( + "@bazel_toolchains//repositories:repositories.bzl", + bazel_toolchains_repositories = "repositories", ) -http_archive( - name = "io_bazel_rules_docker", - sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", - strip_prefix = "rules_docker-0.5.1", - urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +bazel_toolchains_repositories() + +load( + "@io_bazel_rules_docker//repositories:repositories.bzl", + container_repositories = "repositories", ) -load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") +container_repositories() + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", + "remote_config_workspace") remote_config_workspace() +# Apple and Swift rules. 
+http_archive( + name = "build_bazel_rules_apple", + sha256 = "73b4980a318d203d3307f850e27e66ec5cc8d223147a3475a6f11597eb6438a5", + strip_prefix = "rules_apple-0.13.0", + urls = ["https://github.com/bazelbuild/rules_apple/archive/0.13.0.tar.gz"], +) +http_file( + name = "xctestrunner", + executable = 1, + urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"], +) +load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") +apple_rules_dependencies() +load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") +swift_rules_dependencies() + # We must check the bazel version before trying to parse any other BUILD # files, in case the parsing of those build files depends on the bazel # version we require here. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.15.0") +check_bazel_version_at_least("0.19.0") load("//tensorflow:workspace.bzl", "tf_workspace") @@ -108,4 +122,3 @@ http_archive( "http://download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) - diff --git a/tensorflow/opensource_only/arm_compiler.BUILD b/arm_compiler.BUILD similarity index 100% rename from tensorflow/opensource_only/arm_compiler.BUILD rename to arm_compiler.BUILD diff --git a/configure.py b/configure.py index 6c905a0be3d685b5921dfbc5bddfbe6471a82625..3eb09a1ae905b70dc5d02fab7c316f73c79633dd 100644 --- a/configure.py +++ b/configure.py @@ -33,7 +33,7 @@ except ImportError: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDA_VERSION = '10.0' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' @@ -55,6 +55,12 @@ NCCL_LIB_PATHS = [ 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' ] +# List of files to be configured for using Bazel on Apple platforms. 
+APPLE_BAZEL_FILES = [ + 'tensorflow/lite/experimental/objc/BUILD', + 'tensorflow/lite/experimental/swift/BUILD' +] + if platform.machine() == 'ppc64le': _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' else: @@ -255,18 +261,7 @@ def setup_python(environ_cp): def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc') - - data = [] - if os.path.exists(bazelrc_path): - with open(bazelrc_path, 'r') as f: - data = f.read().splitlines() - with open(bazelrc_path, 'w') as f: - for l in data: - if _TF_BAZELRC_FILENAME in l: - continue - f.write('%s\n' % l) - f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME) + def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -488,11 +483,14 @@ def check_bazel_version(min_version, max_version): if curr_version_int < min_version_int: print('Please upgrade your bazel installation to version %s or higher to ' 'build TensorFlow!' % min_version) - sys.exit(0) - if curr_version_int > max_version_int: + sys.exit(1) + if (curr_version_int > max_version_int and + 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ): print('Please downgrade your bazel installation to version %s or lower to ' - 'build TensorFlow!' % max_version) - sys.exit(0) + 'build TensorFlow! To downgrade: download the installer for the old ' + 'version (from https://github.com/bazelbuild/bazel/releases) then ' + 'run the installer.' % max_version) + sys.exit(1) return curr_version @@ -794,8 +792,7 @@ def set_gcc_host_compiler_path(environ_cp): environ_cp, var_name='GCC_HOST_COMPILER_PATH', var_default=default_gcc_host_compiler_path, - ask_for_var= - 'Please specify which gcc should be used by nvcc as the host compiler.', + ask_for_var='Please specify which gcc should be used by nvcc as the host compiler.', check_success=os.path.exists, error_msg='Invalid gcc path. 
%s cannot be found.', ) @@ -1246,6 +1243,7 @@ def set_tf_nccl_install_path(environ_cp): environ_cp['TF_NCCL_VERSION'] = tf_nccl_version write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1282,13 +1280,15 @@ def set_tf_cuda_compute_capabilities(environ_cp): ask_cuda_compute_capabilities = ( 'Please specify a list of comma-separated ' - 'Cuda compute capabilities you want to ' + 'CUDA compute capabilities you want to ' 'build with.\nYou can find the compute ' 'capability of your device at: ' 'https://developer.nvidia.com/cuda-gpus.\nPlease' ' note that each additional compute ' 'capability significantly increases your ' - 'build time and binary size. [Default is: %s]: ' % + 'build time and binary size, and that ' + 'TensorFlow only supports compute ' + 'capabilities >= 3.5 [Default is: %s]: ' % default_cuda_compute_capabilities) tf_cuda_compute_capabilities = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', @@ -1301,12 +1301,14 @@ def set_tf_cuda_compute_capabilities(environ_cp): for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: - print('Invalid compute capability: ' % compute_capability) + print('Invalid compute capability: %s' % compute_capability) all_valid = False else: - ver = int(m.group(0).split('.')[0]) - if ver < 3: - print('Only compute capabilities 3.0 or higher are supported.') + ver = float(m.group(0)) + if ver < 3.5: + print('ERROR: TensorFlow only supports CUDA compute capabilities 3.5 ' + 'and higher. Please re-specify the list of compute ' + 'capabilities excluding version %s.' 
% ver) all_valid = False if all_valid: @@ -1491,7 +1493,35 @@ def set_other_mpi_vars(environ_cp): else: raise ValueError( 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % - mpi_home, mpi_home, mpi_home) + (mpi_home, mpi_home, mpi_home)) + +def system_specific_test_config(env): + """Add default test flags required for TF tests to bazelrc.""" + write_to_bazelrc('test --flaky_test_attempts=3') + write_to_bazelrc('test --test_size_filters=small,medium') + write_to_bazelrc( + 'test --test_tag_filters=-benchmark-test,-no_oss,-oss_serial') + write_to_bazelrc('test --build_tag_filters=-benchmark-test,-no_oss') + if is_windows(): + if env.get('TF_NEED_CUDA', None) == 1: + write_to_bazelrc( + 'test --test_tag_filters=-no_windows,-no_windows_gpu,-no_gpu') + write_to_bazelrc( + 'test --build_tag_filters=-no_windows,-no_windows_gpu,-no_gpu') + else: + write_to_bazelrc('test --test_tag_filters=-no_windows,-gpu') + write_to_bazelrc('test --build_tag_filters=-no_windows,-gpu') + elif is_macos(): + write_to_bazelrc('test --test_tag_filters=-gpu,-nomac,-no_mac') + write_to_bazelrc('test --build_tag_filters=-gpu,-nomac,-no_mac') + elif is_linux(): + if env.get('TF_NEED_CUDA', None) == 1: + write_to_bazelrc('test --test_tag_filters=-no_gpu') + write_to_bazelrc('test --build_tag_filters=-no_gpu') + write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') + else: + write_to_bazelrc('test --test_tag_filters=-gpu') + write_to_bazelrc('test --build_tag_filters=-gpu') def set_system_libs_flag(environ_cp): @@ -1522,10 +1552,6 @@ def set_windows_build_flags(environ_cp): # The host and target platforms are the same in Windows build. So we don't # have to distinct them. This avoids building the same targets twice. write_to_bazelrc('build --distinct_host_configuration=false') - # Enable short object file path to avoid long path issue on Windows. - # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 - # Short object file path will be enabled by default. 
- write_to_bazelrc('build --experimental_shortened_obj_file_path=true') if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', @@ -1546,6 +1572,23 @@ def config_info_line(name, help_text): print('\t--config=%-12s\t# %s' % (name, help_text)) +def configure_apple_bazel_rules(): + """Configures Bazel rules for building on Apple platforms. + + Enables analyzing and building Apple Bazel rules on Apple platforms. This + function will only be executed if `is_macos()` is true. + """ + if not is_macos(): + return + for filepath in APPLE_BAZEL_FILES: + print( + 'Configuring %s file to analyze and build Bazel rules on Apple platforms.' + % filepath) + existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple') + renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath) + os.rename(existing_filepath, renamed_filepath) + + def main(): global _TF_WORKSPACE_ROOT global _TF_BAZELRC @@ -1565,11 +1608,9 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.15.0', '0.20.0') + check_bazel_version('0.19.0', '0.22.0') reset_tf_configure_bazelrc() - # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later - write_to_bazelrc('import %workspace%/tools/bazel.rc') cleanup_makefile() setup_python(environ_cp) @@ -1588,6 +1629,8 @@ def main(): if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' + else: + environ_cp['TF_CONFIGURE_APPLE_BAZEL_RULES'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at @@ -1690,6 +1733,16 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) + system_specific_test_config(os.environ) + + if get_var( + environ_cp, 'TF_CONFIGURE_APPLE_BAZEL_RULES', + 'Configure Bazel rules for Apple platforms', False, + ('Would you like to configure Bazel rules for building on Apple platforms?' 
+ ), 'Configuring Bazel rules for Apple platforms.', + 'Not configuring Bazel rules for Apple platforms.'): + configure_apple_bazel_rules() + print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See .bazelrc for more ' 'details.') @@ -1698,14 +1751,15 @@ def main(): config_info_line('gdr', 'Build with GDR support.') config_info_line('verbs', 'Build with libverbs support.') config_info_line('ngraph', 'Build with Intel nGraph support.') - config_info_line('dynamic_kernels', - '(Experimental) Build kernels into separate shared objects.') + config_info_line( + 'dynamic_kernels', + '(Experimental) Build kernels into separate shared objects.') print('Preconfigured Bazel build configs to DISABLE default on features:') config_info_line('noaws', 'Disable AWS S3 filesystem support.') config_info_line('nogcp', 'Disable GCP support.') config_info_line('nohdfs', 'Disable HDFS support.') - config_info_line('noignite', 'Disable Apacha Ignite support.') + config_info_line('noignite', 'Disable Apache Ignite support.') config_info_line('nokafka', 'Disable Apache Kafka support.') config_info_line('nonccl', 'Disable NVIDIA NCCL support.') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index fd4b94202aad24a82abef8abd16431f61a8326f0..f53982f1efc9885cc12dcc672ad819c762aca378 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -40,12 +40,16 @@ load( # @unused TENSORFLOW_API_INIT_FILES_V2 = ( - TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) + TENSORFLOW_API_INIT_FILES + + get_compat_files(TENSORFLOW_API_INIT_FILES, 2) + + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) # @unused -TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = ( - TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +TENSORFLOW_API_INIT_FILES_V1 = ( + TENSORFLOW_API_INIT_FILES_V1 + + get_compat_files(TENSORFLOW_API_INIT_FILES, 2) + + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) # Config 
setting used when building for products @@ -90,6 +94,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "emscripten", + values = {"crosstool_top": "//external:android/emscripten"}, + visibility = ["//visibility:public"], +) + config_setting( name = "raspberry_pi_armeabi", values = { @@ -202,6 +212,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "arm", + values = {"cpu": "arm"}, + visibility = ["//visibility:public"], +) + config_setting( name = "freebsd", values = {"cpu": "freebsd"}, @@ -267,6 +283,15 @@ config_setting( visibility = ["//visibility:public"], ) +# By default, XLA GPU is compiled into tensorflow when building with +# --config=cuda even when `with_xla_support` is false. The config setting +# here allows us to override the behavior if needed. +config_setting( + name = "no_xla_deps_in_cuda", + define_values = {"no_xla_deps_in_cuda": "true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gdr_support", define_values = {"with_gdr_support": "true"}, @@ -328,6 +353,13 @@ config_setting( }, ) +config_setting( + name = "using_rocm_hipcc", + define_values = { + "using_rocm_hipcc": "true", + }, +) + config_setting( name = "with_mpi_support", values = {"define": "with_mpi_support=true"}, @@ -355,17 +387,18 @@ config_setting( define_values = {"tf_api_version": "2"}, ) +# This flag is defined for select statements that match both +# on 'windows' and 'api_version_2'. In this case, bazel requires +# having a flag which is a superset of these two. 
+config_setting( + name = "windows_and_api_version_2", + define_values = {"tf_api_version": "2"}, + values = {"cpu": "x64_windows"}, +) + package_group( name = "internal", - packages = [ - "-//third_party/tensorflow/python/estimator", - "//learning/meta_rank/...", - "//tensorflow/...", - "//tensorflow_estimator/contrib/...", - "//tensorflow_fold/llgtm/...", - "//tensorflow_text/...", - "//third_party/py/tensor2tensor/...", - ], + packages = ["//tensorflow/..."], ) load( @@ -429,8 +462,7 @@ tf_cc_shared_object( "//tensorflow:darwin": [], "//tensorflow:windows": [], "//conditions:default": [ - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow:tf_framework_version_script.lds)", + "-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)", ], }), linkstatic = 1, @@ -464,15 +496,13 @@ tf_cc_shared_object( name = "libtensorflow.so", linkopts = select({ "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow/c:exported_symbols.lds)", + "-Wl,-exported_symbols_list,$(location //tensorflow/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow/c:version_script.lds)", + "-Wl,--version-script,$(location //tensorflow/c:version_script.lds)", ], }), visibility = ["//visibility:public"], @@ -490,14 +520,12 @@ tf_cc_shared_object( name = "libtensorflow_cc.so", linkopts = select({ "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow:tf_exported_symbols.lds)", + "-Wl,-exported_symbols_list,$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], "//conditions:default": [ 
"-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow:tf_version_script.lds)", + "-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)", ], }), visibility = ["//visibility:public"], @@ -574,13 +602,20 @@ gen_api_init_files( name = "tf_python_api_gen_v1", srcs = [ "api_template_v1.__init__.py", + "compat_template.__init__.py", "compat_template_v1.__init__.py", ], api_version = 1, - compat_api_versions = [1], - compat_init_templates = ["compat_template_v1.__init__.py"], + compat_api_versions = [ + 1, + 2, + ], + compat_init_templates = [ + "compat_template_v1.__init__.py", + "compat_template.__init__.py", + ], output_dir = "_api/v1/", - output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT, + output_files = TENSORFLOW_API_INIT_FILES_V1, output_package = "tensorflow._api.v1", root_file_name = "v1.py", root_init_template = "api_template_v1.__init__.py", @@ -590,11 +625,18 @@ gen_api_init_files( name = "tf_python_api_gen_v2", srcs = [ "api_template.__init__.py", + "compat_template.__init__.py", "compat_template_v1.__init__.py", ], api_version = 2, - compat_api_versions = [1], - compat_init_templates = ["compat_template_v1.__init__.py"], + compat_api_versions = [ + 1, + 2, + ], + compat_init_templates = [ + "compat_template_v1.__init__.py", + "compat_template.__init__.py", + ], output_dir = "_api/v2/", output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", @@ -606,9 +648,11 @@ py_library( name = "tensorflow_py", srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ + deps = select({ + "api_version_2": [], + "//conditions:default": ["//tensorflow/contrib:contrib_py"], + }) + [ ":tensorflow_py_no_contrib", - "//tensorflow/contrib:contrib_py", "//tensorflow/python/estimator:estimator_py", ], ) @@ -618,7 +662,11 @@ py_library( srcs = select({ "api_version_2": [":tf_python_api_gen_v2"], "//conditions:default": 
[":tf_python_api_gen_v1"], - }) + [":root_init_gen"], + }) + [":root_init_gen"] + [ + "//tensorflow/python/keras/api:keras_python_api_gen", + "//tensorflow/python/keras/api:keras_python_api_gen_compat_v1", + "//tensorflow/python/keras/api:keras_python_api_gen_compat_v2", + ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index d81cf067eb07e88e2b8a86cf5643674235eb3f3b..ddcacfcbe2d4d8b089f10f1a771384dc8c4fd199 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -18,27 +18,84 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import distutils as _distutils +import inspect as _inspect import os as _os - -# pylint: disable=g-bad-import-order -from tensorflow.python.tools import component_api_helper as _component_api_helper -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) +import site as _site +import sys as _sys # API IMPORTS PLACEHOLDER # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. # We're using bitwise, but there's nothing special about that. 
-_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable -if _tf_api_dir not in __path__: +_API_MODULE = bitwise # pylint: disable=undefined-variable +_current_module = _sys.modules[__name__] +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) +# pylint: disable=g-bad-import-order +from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg="Limited tf.summary API due to missing TensorBoard installation") +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v2.estimator')) + +if not hasattr(_current_module, 'estimator'): + _component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v2.keras')) + # Enable TF2 behaviors -from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top +from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top _compat.enable_v2_behavior() + +# Load all plugin libraries from site-packages/tensorflow-plugins if we are +# running under pip. +# TODO(gunan): Enable setting an environment variable to define arbitrary plugin +# directories. +# TODO(gunan): Find a better location for this code snippet. +from tensorflow.python.framework import load_library as _ll +from tensorflow.python.lib.io import file_io as _fi + +# Get sitepackages directories for the python installation. 
+_site_packages_dirs = [] +_site_packages_dirs += [_site.USER_SITE] if _site.USER_SITE else [] +_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] +if 'getsitepackages' in dir(_site): + _site_packages_dirs += _site.getsitepackages() + +if 'sysconfig' in dir(_distutils): + _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] + +_site_packages_dirs = list(set(_site_packages_dirs)) + +# Find the location of this exact file. +_current_file_location = _inspect.getfile(_inspect.currentframe()) + +def _running_from_pip_package(): + return any( + _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) + +if _running_from_pip_package(): + for _s in _site_packages_dirs: + # TODO(gunan): Add sanity checks to loaded modules here. + _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') + if _fi.file_exists(_plugin_dir): + _ll.load_library(_plugin_dir) + # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They # must come from this module.
So python adds these symbols for the @@ -59,4 +116,11 @@ try: del compiler except NameError: pass + +# Add module aliases +if hasattr(_current_module, 'keras'): + losses = keras.losses + metrics = keras.metrics + optimizers = keras.optimizers + # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 65bdb6cb1b5e6fb0656a12b932d767aeacfccd29..5eb25a81b7f765f551bc4f1b7ba99b35dbc6b7bb 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -18,20 +18,42 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import distutils as _distutils +import inspect as _inspect import os as _os +import site as _site +import sys as _sys # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) - -# API IMPORTS PLACEHOLDER + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v1.estimator')) +_current_module = _sys.modules[__name__] +if not hasattr(_current_module, 'estimator'): + _component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v1.keras')) from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top -contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +_CONTRIB_WARNING = """ +WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0. 
+For more information, please see: + * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md + * https://github.com/tensorflow/addons +If you depend on functionality not listed there, please file an issue. +""" +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib', + _CONTRIB_WARNING) del LazyLoader # The templated code that replaces the placeholder above sometimes # sets the __all__ variable. If it does, we have to be sure to add @@ -40,14 +62,53 @@ if '__all__' in vars(): vars()['__all__'].append('contrib') from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +# The 'app' module will be imported as part of the placeholder section above. app.flags = flags # pylint: disable=undefined-variable +# Also use 'app' module (choice is arbitrary) to derive the API directory below. +_API_MODULE = app # pylint: disable=undefined-variable + # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable -if _tf_api_dir not in __path__: +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) +# Load all plugin libraries from site-packages/tensorflow-plugins if we are +# running under pip. +# TODO(gunan): Enable setting an environment variable to define arbitrary plugin +# directories. +# TODO(gunan): Find a better location for this code snippet. +from tensorflow.python.framework import load_library as _ll +from tensorflow.python.lib.io import file_io as _fi + +# Get sitepackages directories for the python installation. 
+_site_packages_dirs = [] +_site_packages_dirs += [_site.USER_SITE] if _site.USER_SITE else [] +_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] +if 'getsitepackages' in dir(_site): + _site_packages_dirs += _site.getsitepackages() + +if 'sysconfig' in dir(_distutils): + _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] + +_site_packages_dirs = list(set(_site_packages_dirs)) + +# Find the location of this exact file. +_current_file_location = _inspect.getfile(_inspect.currentframe()) + +def _running_from_pip_package(): + return any( + _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) + +if _running_from_pip_package(): + for _s in _site_packages_dirs: + # TODO(gunan): Add sanity checks to loaded modules here. + _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') + if _fi.file_exists(_plugin_dir): + _ll.load_library(_plugin_dir) # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 25df970ecab0757f23465ab19e7f45de0c759458..ef7863dc0d5cbd57da30baa6e04278c2a0354b25 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -67,6 +67,23 @@ tf_cuda_library( tf_cuda_library( name = "c_api", + hdrs = ["c_api.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":c_api_no_xla", + ":c_api_internal", + ] + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], + }), +) + +tf_cuda_library( + name = "c_api_no_xla", srcs = [ "c_api.cc", "c_api_function.cc", @@ -75,15 +92,13 @@ tf_cuda_library( "c_api.h", ], copts = tf_copts(), - visibility = ["//visibility:public"], - deps = select({ + visibility = ["//tensorflow/c:__subpackages__"], + deps = [":c_api_internal"] + select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ],
"//conditions:default": [ - ":c_api_internal", - "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", "//tensorflow/cc:grad_ops", @@ -97,13 +112,8 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/kernels:logging_ops", ], - }) + select({ - "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/jit", - ], - "//conditions:default": [], }), ) @@ -123,13 +133,13 @@ tf_cuda_library( "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", "//tensorflow/compiler/jit:flags", - "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_platform", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", + "@com_google_absl//absl/strings", ], ) @@ -156,8 +166,8 @@ tf_cuda_library( hdrs = ["tf_status_helper.h"], visibility = ["//visibility:public"], deps = [ - ":c_api", ":c_api_internal", + ":c_api_no_xla", "//tensorflow/core:lib", ], ) @@ -190,14 +200,12 @@ tf_cuda_library( ":c_api", ":tf_status_helper", "//tensorflow/core:android_tensorflow_lib_lite", - "//tensorflow/core:platform_env", "//tensorflow/core:lib", ], "//conditions:default": [ ":c_api", ":tf_status_helper", "//tensorflow/core:framework", - "//tensorflow/core:platform_env", "//tensorflow/core:lib", ], }) + [":c_api_internal"], @@ -215,13 +223,13 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - ":c_api", + ":c_api_no_xla", ":c_api_internal", ":tf_status_helper", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":c_api", + ":c_api_no_xla", ":c_api_internal", ":tf_status_helper", "//tensorflow/core:framework", @@ -251,6 +259,18 @@ tf_cuda_library( ], ) +tf_cc_test( + name 
= "c_test", + srcs = ["c_test.c"], + extra_copts = ["-std=c11"], + deps = [ + ":c_api", + ":c_api_experimental", + ":env", + ":kernels", + ], +) + tf_cuda_cc_test( name = "c_api_test", size = "small", @@ -279,13 +299,23 @@ tf_cuda_cc_test( "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/compiler/jit", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:bitwise_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:spectral_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/kernels:array", @@ -309,6 +339,7 @@ tf_cc_test( deps = [ ":c_api", ":c_api_experimental", + ":c_api_internal", ":c_test_util", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_test_util", @@ -325,6 +356,7 @@ tf_cc_test( srcs = ["c_api_function_test.cc"], deps = [ ":c_api", + ":c_api_internal", ":c_test_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 94d18eb8b04e3534be547aca5cfbb32da40ffbf6..245d7ba2b186895532953aa61ebfc3fc6bf635a7 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/kernels/logging_ops.h" #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -257,6 +258,74 @@ int64_t TF_Dim(const TF_Tensor* t, int dim_index) { size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); } void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } +int64_t TF_TensorElementCount(const TF_Tensor* t) { + int64_t result = 1; + int rank = TF_NumDims(t); + for (int dim = 0; dim < rank; ++dim) { + result *= TF_Dim(t, dim); + } + return result; +} + +// Returns the number of elements that would be present in a tensor with the +// given shape. +static int64_t ShapeNumElements(const int64_t* dims, int num_dims) { + int64_t result = 1; + for (int dim = 0; dim < num_dims; ++dim) { + result *= dims[dim]; + } + return result; +} + +static void UnrefIfNonNull(::tensorflow::TensorBuffer* buf) { + if (buf != nullptr) { + buf->Unref(); + } +} + +static void RefIfNonNull(::tensorflow::TensorBuffer* buf) { + if (buf != nullptr) { + buf->Ref(); + } +} + +void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type, + TF_Tensor* to, const int64_t* new_dims, + int num_new_dims, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + size_t in_size = TF_DataTypeSize(TF_TensorType(from)); + if (in_size == 0) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "input tensor has a zero-sized data type"); + return; + } + size_t out_size = TF_DataTypeSize(type); + if (out_size == 0) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "output tensor has a zero-sized data type"); + return; + } + + if (ShapeNumElements(new_dims, num_new_dims) * out_size != + TF_TensorElementCount(from) * in_size) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "input tensor is not compatible with output shape"); + return; + } + + tensorflow::TensorShapeProto 
p; + for (int i = 0; i < num_new_dims; ++i) { + p.add_dim()->set_size(new_dims[i]); + } + to->shape = tensorflow::TensorShape(p); + to->dtype = type; + if (to->buffer != from->buffer) { + UnrefIfNonNull(to->buffer); + to->buffer = from->buffer; + RefIfNonNull(to->buffer); + } +} + // -------------------------------------------------------------------------- size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t dst_len, TF_Status* status) { @@ -488,6 +557,7 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { // Non-static for testing. TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); if (!src.IsInitialized()) { status->status = FailedPrecondition( "attempt to use a tensor with an uninitialized value"); @@ -571,7 +641,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, dimvec.size(), base, size, DeleteArray, base); } -Status MessageToBuffer(const tensorflow::protobuf::Message& in, +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, TF_Buffer* out) { if (out->data != nullptr) { return InvalidArgument("Passing non-empty TF_Buffer is invalid."); @@ -1241,6 +1311,13 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, reinterpret_cast<const DataType*>(values), num_values)); } +void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name, + const char* placeholder) { + tensorflow::AttrValue attr_value; + attr_value.set_placeholder(placeholder); + desc->node_builder.Attr(attr_name, attr_value); +} + void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, const char* value, size_t length) { tensorflow::NameAttrList func_name; @@ -2880,6 +2957,16 @@ const char* TF_ServerTarget(TF_Server* server) { #endif } -void TF_DeleteServer(TF_Server* server) { delete server; } +void TF_DeleteServer(TF_Server* server) { +#ifndef __ANDROID__ + delete server; +#endif +} + +void
TF_RegisterLogListener(void (*listener)(const char*)) { +#ifndef __ANDROID__ + tensorflow::logging::RegisterListener(listener); +#endif +} } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c7abba85521fccec07983cd5ab4f94a8368d6181..051de3a7dc0f8c630b6c81d2cfa960e5279c93c0 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -272,6 +272,39 @@ TF_CAPI_EXPORT extern size_t TF_TensorByteSize(const TF_Tensor*); // Return a pointer to the underlying data buffer. TF_CAPI_EXPORT extern void* TF_TensorData(const TF_Tensor*); +// Returns the number of elements in the tensor. +TF_CAPI_EXPORT extern int64_t TF_TensorElementCount(const TF_Tensor* tensor); + +// Copy the internal data representation of `from` to `to`. `new_dims` and +// `num_new_dims` specify the new shape of the `to` tensor, `type` specifies its +// data type. On success, *status is set to TF_OK and the two tensors share the +// same data buffer. +// +// This call requires that the `from` tensor and the given type and shape +// (new_dims and num_new_dims) are "compatible" (i.e. have equal byte size). +// Specifically, given from_type_size = TF_DataTypeSize(TF_TensorType(from)): +// +// ShapeElementCount(new_dims, num_new_dims) * TF_DataTypeSize(type) +// +// must equal +// +// TF_TensorElementCount(from) * from_type_size +// +// where ShapeElementCount would be the number of elements in a tensor with +// the given shape. +// +// In addition, this function requires: +// * TF_DataTypeSize(TF_TensorType(from)) != 0 +// * TF_DataTypeSize(type) != 0 +// +// If any of the requirements are not met, *status is set to +// TF_INVALID_ARGUMENT.
+TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from, + TF_DataType type, TF_Tensor* to, + const int64_t* new_dims, + int num_new_dims, + TF_Status* status); + // -------------------------------------------------------------------------- // Encode the string `src` (`src_len` bytes long) into `dst` in the format // required by TF_STRING tensors. Does not write to memory more than `dst_len` @@ -516,6 +549,10 @@ TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, const TF_DataType* values, int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc, + const char* attr_name, + const char* placeholder); + // Set a 'func' attribute to the specified name. // `value` must point to a string of length `length` bytes. TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, @@ -1277,6 +1314,28 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); +// Similar to TF_GraphToFunction but allows specifying control outputs of the +// function. +// +// The arguments of TF_GraphToFunction have the same meaning, but the new +// arguments are as follows: +// +// ncontrol_outputs: Number of control outputs of the function. +// control_outputs: vector of TF_Operation objects to be marked as control +// outputs of the function. Operations marked as control outputs are +// guaranteed to execute. +// control_output_names: Optional. If not nullptr, vector of strings, one +// per control output, with their names to be added to the function's +// OpDef. 
+TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status); + // Returns the name of the graph function. // The return value points to memory that is only usable until the next // mutation to *func. @@ -1710,6 +1769,14 @@ TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); // it will be stopped and joined. TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); +// Register a listener method that processes printed messages. +// +// If any listeners are registered, the print operator will call all listeners +// with the printed messages and immediately return without writing to the +// logs. +TF_CAPI_EXPORT extern void TF_RegisterLogListener( + void (*listener)(const char*)); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 38e29aa74a90f4e85d1369b6928a5a58c531b2da..7ff4084decc686b067226ecaecf2af29d51d42f2 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" +#include "absl/strings/substitute.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api.h" @@ -66,7 +67,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { } TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, - unsigned char gpu_memory_allow_growth) { + unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices) { tensorflow::ConfigProto config; auto* optimizer_options = config.mutable_graph_options()->mutable_optimizer_options(); @@ -87,6 +89,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + (*config.mutable_device_count())["CPU"] = num_cpu_devices; + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a // threadpool, so that we avoid the possibility of running the runner_ in the // threadpool of GPU event mgr, as that can trigger more callbacks to be @@ -125,6 +129,14 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { return ret; } +char* TF_FunctionDebugString(TF_Function* func, size_t* len) { + const auto& debug_str = func->fdef.DebugString(); + *len = debug_str.size(); + char* ret = static_cast(malloc(*len + 1)); + memcpy(ret, debug_str.c_str(), *len + 1); + return ret; +} + // On success, returns a set of TF_Function instances from `text_proto` of // GraphDef type. These functions must be deleted by calling TF_DeleteFunction. // @@ -8535,8 +8547,9 @@ TFE_Context* TFE_CreateContextFromSession(TF_Session* session, // Reduce GPU memory allocation, and set appropriate config options for TFE // context. 
- auto* config = - TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + auto* config = TF_CreateConfig( + /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 10); TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); if (!status->status.ok()) { CHECK(!config); @@ -8733,6 +8746,12 @@ static void CheckOk(TF_Status* status) { void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { auto* status = TF_NewStatus(); + if (!TFE_TensorHandleIsConcrete(handle)) { + VLOG(1) << "Symbolic tensor: " << handle; + TF_DeleteStatus(status); + return; + } + TF_Tensor* t = TFE_TensorHandleResolve(handle, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -8744,6 +8763,11 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { TF_DeleteStatus(status); } +void TFE_OpPrintDebugString(TFE_Op* op) { + VLOG(1) << "TFE_OpPrintDebugString() over " << op; + LOG(INFO) << op->operation.DebugString(); +} + struct TFE_ExecuteOpNotification { TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {} tensorflow::Notification n; @@ -8886,3 +8910,240 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg, std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); return new TFE_TensorHandle(tensor, nullptr, nullptr); } + +namespace { +tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { + // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the + // server object (which currently CHECK-fails) and we miss the error, instead, + // we log the error, and then return to allow the user to see the error + // message. +#define LOG_AND_RETURN_IF_ERROR(...) 
\ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.error_message(); \ + return _status; \ + } \ + } while (0); + + std::unique_ptr server; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); + + tensorflow::GrpcServer* grpc_server = + dynamic_cast(server.get()); + if (grpc_server == nullptr) { + LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); + } + + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); + + LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer( + std::move(server), grpc_server->worker_env()->device_mgr, + grpc_server->worker_env()->collective_executor_mgr)); + + return tensorflow::Status::OK(); +#undef LOG_AND_RETURN_IF_ERROR +} +} // namespace + +// Set server_def on the context, possibly updating it. +TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = EnableCollectiveOps(server_def, ctx); +} + +std::string tensorflow::getTF_OutputDebugString(TF_Output node) { + return absl::Substitute("TF_Output($0, $1)", node.oper, node.index); +} + +using tensorflow::getTF_OutputDebugString; + +TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput(TF_Output t, + TF_DataType dtype) { + auto ret = new TFE_TensorHandle(t, dtype); + VLOG(1) << "Storing TFOutput " << getTF_OutputDebugString(t) + << " into tensor handle " << ret << " with internal handle " + << ret->handle; + return ret; +} + +unsigned char TFE_TensorHandleIsConcrete(TFE_TensorHandle* handle) { + assert(handle->handle != nullptr); + return handle->handle->getSymbolicTensor() == nullptr; +} + +TF_Output 
TFE_GetTFOutputFromTensorHandle(TFE_TensorHandle* handle, + TF_Status* status) { + if (TFE_TensorHandleIsConcrete(handle)) { + status->status = + tensorflow::errors::Internal("Not a symbolic tensor: ", handle); + return TF_Output{nullptr, -1}; + } + + auto* sym_tensor = handle->handle->getSymbolicTensor(); + CHECK(sym_tensor != nullptr); + auto ret = TF_Output{sym_tensor->oper, sym_tensor->index}; + VLOG(1) << "Retrieving " << getTF_OutputDebugString(ret) + << " from tensor handle " << handle; + CHECK_GE(sym_tensor->index, 0); + return ret; +} + +TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph) { + return new TFE_TraceContext(graph); +} + +void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx) { delete trace_ctx; } + +// If `handle` is already symbolic, return it. Otherwise map it to a new +// symbolic tensor (a PlaceHolder op) and return that. +static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx, + tensorflow::TensorHandle* handle, + TF_Status* status) { + VLOG(1) << "Getting symbolic tensor for input tensor handle " << handle + << ": " << handle->DebugString(); + + auto* sym_tensor = handle->getSymbolicTensor(); + if (sym_tensor != nullptr) { + auto ret = TF_Output{sym_tensor->oper, sym_tensor->index}; + VLOG(1) << "This handle is a symbolic tensor " << sym_tensor << ": " + << getTF_OutputDebugString(ret); + return ret; + } + + auto find_it = trace_ctx->input_tensor_map.find(handle); + if (find_it != trace_ctx->input_tensor_map.end()) { + VLOG(1) << "There exists a map entry from this concrete tensor to: " + << getTF_OutputDebugString(find_it->second); + return find_it->second; + } + + auto node_name = tensorflow::strings::StrCat("additional_input_", + trace_ctx->node_counter++); + VLOG(1) << "Adding a place holder node named " << node_name; + auto* desc = + TF_NewOperation(trace_ctx->graph, "Placeholder", node_name.c_str()); + TF_SetAttrType(desc, "dtype", + static_cast(handle->dtype) /*TF_FLOAT*/); + auto* result = 
TF_FinishOperation(desc, status); + if (!status->status.ok()) { + return TF_Output{nullptr, -1}; + } + + auto ret = TF_Output{result, 0}; + VLOG(1) << "Creating a new map entry to map to: " + << getTF_OutputDebugString(ret); + trace_ctx->input_tensor_map[handle] = ret; + // `handle` could be destroyed before it's read from `input_tensor_map` (say + // during a subsequent TFE_FinalizeInputTensorsFromTraceContext() call), so we + // increment its ref count to extend its life span to that of `trace_ctx`. + handle->Ref(); + VLOG(1) << "Ref count for handle " << handle + << " is 1?: " << handle->RefCountIsOne(); + return ret; +} + +TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, + TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status) { + VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": " + << op->operation.DebugString(); + + const auto& op_type = op->operation.Name(); + auto op_name = + tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++); + auto* desc = + TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); + + VLOG(1) << "Adding attrs."; + tensorflow::AttrValueMap attrs; + op->operation.Attrs().FillAttrValueMap(&attrs); + for (const auto& attr : attrs) { + desc->node_builder.Attr(attr.first, attr.second); + } + + VLOG(1) << "Adding inputs."; + const auto& inputs = op->operation.Inputs(); + size_t inputIndex = 0; + const tensorflow::OpDef& op_def = desc->node_builder.op_def(); + for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) { + // TODO(bgogul): Add support for number attributes. 
+ DCHECK(input_arg.number_attr().empty()) + << "Number attributes is not implemented yet."; + if (input_arg.type_list_attr().empty()) { + auto symbolic_input = + getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); + if (!status->status.ok()) return nullptr; + TF_AddInput(desc, symbolic_input); + continue; + } + const std::string& type_list_attr = input_arg.type_list_attr(); + const auto& attr_value = attrs[type_list_attr]; + DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList) + << "Type list attribute should be a list!"; + std::vector list_inputs(attr_value.list().type_size()); + for (TF_Output& list_input : list_inputs) { + list_input = + getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); + if (!status->status.ok()) return nullptr; + } + TF_AddInputList(desc, list_inputs.data(), list_inputs.size()); + } + + auto* graph_op = TF_FinishOperation(desc, status); + if (!status->status.ok()) return nullptr; + + VLOG(1) << "Op finalized; setting return tensors."; + *num_retvals = TF_OperationNumOutputs(graph_op); + VLOG(1) << "This op has " << *num_retvals << " outputs."; + for (int i = 0; i < *num_retvals; ++i) { + auto output = TF_Output{graph_op, i}; + auto dtype = TF_OperationOutputType(output); + retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype); + } + return graph_op; +} + +int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) { + if (trace_ctx->input_tensors == nullptr) { + trace_ctx->input_tensors = + new std::vector>(); + trace_ctx->input_tensors->reserve(trace_ctx->input_tensor_map.size()); + + for (auto input : trace_ctx->input_tensor_map) { + trace_ctx->input_tensors->emplace_back(input.first, input.second); + } + } + return trace_ctx->input_tensor_map.size(); +} + +TF_Output TFE_GetInputGraphNodeFromTraceContext(TFE_TraceContext* trace_ctx, + unsigned int idx) { + CHECK(trace_ctx->input_tensors != nullptr); + CHECK(trace_ctx->input_tensors->size() > idx); + return 
trace_ctx->input_tensors->at(idx).second; +} + +TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext( + TFE_TraceContext* trace_ctx, unsigned int idx) { + CHECK(trace_ctx->input_tensors != nullptr); + CHECK(trace_ctx->input_tensors->size() > idx); + auto* handle = trace_ctx->input_tensors->at(idx).first; + VLOG(1) << "Ref count for internal handle " << handle + << " is 1?: " << handle->RefCountIsOne(); + handle->Ref(); + auto* ret = new TFE_TensorHandle(handle); + VLOG(1) << "Returning a new tensor handle " << ret << ": " + << handle->DebugString(); + return ret; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 3e3a485eb763b871b0551414c4ef04746b2ed9a3..8d1a8b82fbaf9901b6d9aecf6d092ae298c8dba3 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -67,9 +67,10 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, // a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if // `enable_xla_compilation` is non-zero, and OFF otherwise. // b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +// c) ConfigProto.device_count["CPU"] is set to `num_cpu_devices`. TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( - unsigned char enable_xla_compilation, - unsigned char gpu_memory_allow_growth); + unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices); // Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level // is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE @@ -83,6 +84,15 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, size_t* len); +// Returns the function content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it.
+// +// Do not return const char*, because some foreign language bindings +// (e.g. Swift) cannot then call free() on the returned pointer. +TF_CAPI_EXPORT extern char* TF_FunctionDebugString(TF_Function* func, + size_t* len); + // Creates a stack of data set + iterator nodes, currently hard-coded to return // a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, // returns the IteratorGetNext node, which caller can run or feed into a node. @@ -180,6 +190,8 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( TFE_TensorHandle* handle); +TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op); + typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification; // Allows invoking a kernel asynchronously, and explicitly returns a @@ -246,6 +258,62 @@ TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void); TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar( TF_DataType dtype, void* scalar, size_t len); +// Specify the server_def that enables collective ops. +// This is different to the above function in that it doesn't create remote +// contexts, and remotely executing ops is not possible. It just enables +// communication for collective ops. +TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Create a symbolic tensor from the input graph node. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput( + TF_Output t, TF_DataType data_type); + +// Returns 0 if the input tensor handle represents a symbolic tensor (i.e., a +// graph node). Otherwise returns non-0. +TF_CAPI_EXPORT extern unsigned char TFE_TensorHandleIsConcrete( + TFE_TensorHandle* handle); + +// If `handle` is a symbolic tensor, return the corresponding graph node +// represented by TF_Output. Otherwise, return an error status.
+TF_CAPI_EXPORT extern TF_Output TFE_GetTFOutputFromTensorHandle( + TFE_TensorHandle* handle, TF_Status* status); + +typedef struct TFE_TraceContext TFE_TraceContext; + +// A trace context contains a trace graph, to which TFE_AddEagerOpToGraph() +// calls add graph nodes as a way to symbolically execute the eager ops. +// +// It also contains a hash map from concrete input tensors to symbolic +// tensors. That map will be used to create input tensors to the trace graph. +TF_CAPI_EXPORT extern TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph); + +TF_CAPI_EXPORT extern void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx); + +// Symbolically executes `op`, by adding a corresponding node to the graph +// associated with `trace_ctx`. This graph node outputs a set of symbolic +// tensors in `retvals` and `num_retvals`. Returns the corresponding graph +// operation on success, otherwise returns nullptr. +TF_CAPI_EXPORT extern TF_Operation* TFE_AddEagerOpToGraph( + TFE_Op* op, TFE_TraceContext* trace_ctx, TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status); + +// Finalizes the trace graph and its inputs, and returns the number of inputs. +// After this call, the next two APIs can be called to iterate over the input +// tensors. +TF_CAPI_EXPORT extern int TFE_FinalizeInputTensorsFromTraceContext( + TFE_TraceContext* trace_ctx); + +TF_CAPI_EXPORT extern TF_Output TFE_GetInputGraphNodeFromTraceContext( + TFE_TraceContext* trace_ctx, unsigned int idx); + +// Each input tensor should be consumed at most once. 
+TF_CAPI_EXPORT extern TFE_TensorHandle* +TFE_ConsumeInputConcreteTensorFromTraceContext(TFE_TraceContext* trace_ctx, + unsigned int idx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index daa7701b7fe7e8ce757b6504329cf6434ad39778..c54021a7517ebbdd00405cbfa9cee8f3f6616cca 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_test_util.h" @@ -296,5 +297,178 @@ TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) { TF_DeleteStatus(status); } +TEST(CAPI_EXPERIMENTAL, SymbolicTensor) { + TF_Status* status = TF_NewStatus(); + auto node = TF_Output{nullptr, 1}; + auto* sym_handle = TFE_NewTensorHandleFromTFOutput(node, TF_FLOAT); + TFE_TensorHandlePrintDebugString(sym_handle); + CHECK_EQ(TFE_TensorHandleDataType(sym_handle), TF_FLOAT); + ASSERT_FALSE(TFE_TensorHandleIsConcrete(sym_handle)); + auto same_node = TFE_GetTFOutputFromTensorHandle(sym_handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(same_node.oper, node.oper); + ASSERT_EQ(same_node.index, node.index); + TFE_DeleteTensorHandle(sym_handle); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + ASSERT_TRUE(TFE_TensorHandleIsConcrete(m)); + (void)TFE_GetTFOutputFromTensorHandle(m, status); + CHECK_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(m); + + TF_DeleteStatus(status); +} + +class AddEagerOpToGraphTest : public ::testing::Test { + protected: + AddEagerOpToGraphTest() + : status_(TF_NewStatus()), + eager_ctx_(nullptr), + graph_(TF_NewGraph()), + 
trace_ctx_(TFE_NewTraceContext(graph_)) { + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + eager_ctx_ = TFE_NewContext(opts, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_DeleteContextOptions(opts); + } + + ~AddEagerOpToGraphTest() override { + TFE_DeleteTraceContext(trace_ctx_); + TF_DeleteGraph(graph_); + TFE_DeleteContext(eager_ctx_); + TF_DeleteStatus(status_); + } + + template + void AddEagerOpToGraphAndCheck(TFE_Op* op, Callable checker) { + TFE_TensorHandle* retvals[5]; + int num_retvals = 5; + // Symbolically execute this op, which adds a graph node to `trace_ctx_`. + TF_Operation* graph_op = + TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_NOTNULL(graph_op); + // Check the expectations. + checker(graph_op); + for (int i = 0; i < num_retvals; ++i) { + TFE_DeleteTensorHandle(retvals[i]); + } + } + + TF_Status* status_; + TFE_Context* eager_ctx_; + TF_Graph* graph_; + TFE_TraceContext* trace_ctx_; +}; + +TEST_F(AddEagerOpToGraphTest, DebugPrintAndSymbolicExecution) { + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* op = MatMulOp(eager_ctx_, m, m); + + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpPrintDebugString(op); + + TFE_TensorHandle* retvals[5]; + int num_retvals = 5; + // Symbolically execute this op, which adds a graph node to `trace_ctx`. 
+ TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx_); + CHECK_EQ(num_inputs, 1); + auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx_, + /*idx*/ 0); + + LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor); + auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx_, + /*idx*/ 0); + TFE_TensorHandlePrintDebugString(handle); + TFE_DeleteTensorHandle(handle); + + CHECK_EQ(num_retvals, 1); + CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT); + + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(m); + TFE_DeleteOp(op); +} + +TEST_F(AddEagerOpToGraphTest, ValueAttributesArePreserved) { + // Create MinOp + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* op = MinOp(eager_ctx_, axis, axis); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + // Check the attributes set by the call to MinOp above. + AddEagerOpToGraphAndCheck(op, [this, &axis](TF_Operation* graph_op) { + unsigned char value; + TF_OperationGetAttrBool(graph_op, "keep_dims", &value, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(value, 1); + TF_DataType dtype; + TF_OperationGetAttrType(graph_op, "Tidx", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TF_INT32); + TF_OperationGetAttrType(graph_op, "T", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TFE_TensorHandleDataType(axis)); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteOp(op); +} + +TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) { + // Create a "Squeeze" operator with list attributes. 
+ TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* squeeze = TFE_NewOp(eager_ctx_, "Squeeze", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpAddInput(squeeze, axis, status_); + TFE_OpSetAttrType(squeeze, "T", TF_INT32); + std::vector boundaries = {1, 2, 3, 4}; + TFE_OpSetAttrIntList(squeeze, "squeeze_dims", boundaries.data(), + boundaries.size()); + // Check attributes are preserved. + AddEagerOpToGraphAndCheck( + squeeze, [this, &boundaries](TF_Operation* squeeze_graph_op) { + TF_DataType dtype; + TF_OperationGetAttrType(squeeze_graph_op, "T", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TF_INT32); + std::unique_ptr list(new int64_t[boundaries.size()]); + TF_OperationGetAttrIntList(squeeze_graph_op, "squeeze_dims", list.get(), + boundaries.size(), status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_TRUE(std::equal(list.get(), list.get() + boundaries.size(), + boundaries.begin())); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteOp(squeeze); +} + +TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) { + TFE_TensorHandle* scalar = TestScalarTensorHandle(); + TFE_Op* identityn = TFE_NewOp(eager_ctx_, "IdentityN", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + constexpr size_t kNumInputs = 3; + for (size_t i = 0; i < kNumInputs; ++i) { + TFE_OpAddInput(identityn, scalar, status_); + } + TF_DataType types[kNumInputs] = {TF_FLOAT, TF_FLOAT, TF_FLOAT}; + TFE_OpSetAttrTypeList(identityn, "T", types, kNumInputs); + AddEagerOpToGraphAndCheck( + identityn, [this, kNumInputs](TF_Operation* graph_op) { + EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs); + EXPECT_EQ(TF_OperationInputListLength(graph_op, "input", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_EQ(TF_OperationOutputListLength(graph_op, "output", status_), + kNumInputs); + CHECK_EQ(TF_OK, 
TF_GetCode(status_)) << TF_Message(status_); + }); + TFE_DeleteTensorHandle(scalar); + TFE_DeleteOp(identityn); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 28b9f8df9c873ee394eb6a241dd9ac06ba6c8796..03d65ecefd4a9ba5a23a94ed902dfba6dd4fbda9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -162,6 +162,11 @@ Status FillFunctionBody( const std::vector& body_nodes, const std::unordered_map& tensor_renaming, FunctionDef* fdef) { + std::unordered_set func_attr_names; + for (const auto& func_attr : fdef->signature().attr()) { + func_attr_names.insert(func_attr.name()); + } + std::vector in_edges; std::vector control_edges; for (const Node* node : body_nodes) { @@ -243,6 +248,48 @@ Status FillFunctionBody( if (node->op_def().is_stateful()) { fdef->mutable_signature()->set_is_stateful(true); } + + // If this node has any attributes with placeholder value, add the + // attribute to FunctionDef signature. + for (const auto& iter : node->attrs()) { + if (iter.second.placeholder().empty()) { + continue; + } + + // If we already added the attribute, skip it. + string func_attr_name = iter.second.placeholder(); + if (func_attr_names.find(func_attr_name) != func_attr_names.end()) { + continue; + } + + // This node's attribute is a placeholder value, so it does not have type + // information. We check node's OpDef for attribute type. + string node_attr_name = iter.first; + const OpDef::AttrDef* node_attr_def = nullptr; + for (const auto& node_attr : node->op_def().attr()) { + if (node_attr.name() == node_attr_name) { + node_attr_def = &node_attr; + } + } + if (!node_attr_def) { +#ifdef TENSORFLOW_LITE_PROTOS + return errors::Unimplemented( + "Placeholder value is not supported for attributes not in OpDef. " + "Attribute: ", + node_attr_name); +#else + return errors::Unimplemented( + "Placeholder value is not supported for attributes not in OpDef. 
" + "Attribute: ", + node_attr_name, ", OpDef: ", node->op_def().DebugString()); +#endif + } + OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr(); + attr_def->set_name(func_attr_name); + attr_def->set_type(node_attr_def->type()); + + func_attr_names.insert(func_attr_name); + } } return Status::OK(); } @@ -255,6 +302,8 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, const std::vector& inputs, const std::vector& outputs, const std::vector& output_names, + const std::vector& control_outputs, + const std::vector& control_output_names, const char* description, FunctionDef* fdef) { if (!output_names.empty()) { DCHECK_EQ(output_names.size(), outputs.size()); @@ -378,6 +427,29 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, fdef->mutable_signature()->set_name(fn_name); } + if (!control_output_names.empty() && + (control_outputs.size() != control_output_names.size())) { + return InvalidArgument( + "Expected number of control outputs (", control_outputs.size(), + ") and the number of control output names (", + control_output_names.size(), ") to match but they do not."); + } + std::unordered_set control_output_names_set; + for (int i = 0; i < control_outputs.size(); ++i) { + string signature_name; + if (!control_output_names.empty()) { + signature_name = control_output_names[i]; + } else { + signature_name = control_outputs[i]->name(); + } + if (!control_output_names_set.insert(signature_name).second) { + return errors::InvalidArgument("Repeated control output name: ", + signature_name); + } + fdef->mutable_signature()->add_control_output(signature_name); + (*fdef->mutable_control_ret())[signature_name] = control_outputs[i]->name(); + } + return Status::OK(); } @@ -485,14 +557,14 @@ Status ComputeBodyNodes( using tensorflow::Node; using tensorflow::string; -TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, - unsigned char append_hash_to_fn_name, - int num_opers, const TF_Operation* 
const* opers, - int ninputs, const TF_Output* inputs, - int noutputs, const TF_Output* outputs, - const char* const* output_names, - const TF_FunctionOptions* opts, - const char* description, TF_Status* status) { +TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status) { tensorflow::mutex_lock l(*const_cast(&fn_body->mu)); // Process inputs. @@ -517,19 +589,34 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, } } + // Process control output names. + std::vector control_output_names_vec; + if (control_output_names) { + control_output_names_vec.reserve(ncontrol_outputs); + for (int i = 0; i < ncontrol_outputs; ++i) { + control_output_names_vec.push_back(string(control_output_names[i])); + } + } + // Compute body nodes. std::vector body_nodes; status->status = tensorflow::ComputeBodyNodes( fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); if (!status->status.ok()) return nullptr; + // Process control outputs. + std::vector control_output_nodes; + for (int i = 0; i < ncontrol_outputs; ++i) { + control_output_nodes.push_back(&control_outputs[i]->node); + } + // Do the actual function creation.
TF_Function* tf_function = new TF_Function(); DCHECK(append_hash_to_fn_name <= 1); status->status = tensorflow::GraphToFunctionDef( fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes, - input_tensors, output_tensors, output_names_vec, description, - &tf_function->fdef); + input_tensors, output_tensors, output_names_vec, control_output_nodes, + control_output_names_vec, description, &tf_function->fdef); if (!status->status.ok()) { TF_DeleteFunction(tf_function); return nullptr; @@ -537,6 +624,20 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, return tf_function; } +TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, + int num_opers, const TF_Operation* const* opers, + int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, + const char* const* output_names, + const TF_FunctionOptions* opts, + const char* description, TF_Status* status) { + return TF_GraphToFunctionWithControlOutputs( + fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs, + inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts, + description, status); +} + const char* TF_FunctionName(TF_Function* func) { return func->fdef.signature().name().c_str(); } diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 73fe73769bc1219ce865149d67d333c53371ccc5..946f8c4a2c3fb25f908d809e00bf579b40a8668b 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -1230,6 +1231,53 @@ void DefineFunction(const char* name, TF_Function** func, ASSERT_NE(*func, nullptr); } +REGISTER_OP("CustomOp") + .Output("output: float32") + .Attr("index: int") + .SetShapeFn(tensorflow::shape_inference::UnknownShape); + +void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s, + const char* name, const char* placeholder, + TF_Operation** op) { + TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name); + TF_SetAttrPlaceholder(desc, "index", placeholder); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Operation *node1, *node2, *node3; + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node1", "v1", + &node1); + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node2", "v1", + &node2); + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2", + &node3); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}}; + func_ = TF_GraphToFunction( + func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1, + /*opers=*/nullptr, 0, inputs, 3, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, /*description=*/nullptr, s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(func_, nullptr); + + // Verify that FunctionDef has 2 attributes, "v1" and "v2". 
+ ASSERT_EQ(func_->fdef.signature().attr().size(), 2); + EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1"); + EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int"); + EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2"); + EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int"); +} + TEST_F(CApiFunctionTest, SetGradientAndRun) { // Define the function and its grad DefineFunction(func_name_, &func_); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 5ba26d3c585350aa510f9970cbfc246a9a108543..d520b6b76849e562def6abd8be0510d3b4797e8c 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -204,7 +204,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); -Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out); // Set the shapes and types of the output's handle. 
// @@ -228,6 +229,8 @@ void RecordMutation(TF_Graph* graph, const TF_Operation& op, bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) LOCKS_EXCLUDED(session->graph->mu, session->mu); +std::string getTF_OutputDebugString(TF_Output node); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index d5934a10395ae094f65d3bc8b6cd7b94dbd32410..2be03bf0de6277fc63c353ad6dc63bec096a6993 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -163,6 +163,7 @@ TEST(CAPI, AllocateTensor) { EXPECT_EQ(dims[0], TF_Dim(t, 0)); EXPECT_EQ(dims[1], TF_Dim(t, 1)); EXPECT_EQ(num_bytes, TF_TensorByteSize(t)); + EXPECT_EQ(6, TF_TensorElementCount(t)); TF_DeleteTensor(t); } @@ -1467,6 +1468,41 @@ TEST(CAPI, DeletingNullPointerIsSafe) { TF_DeleteStatus(status); } +TEST(CAPI, TestBitcastFrom_Reshape) { + int64_t dims[] = {2, 3}; + TF_Tensor* a = + TF_AllocateTensor(TF_UINT64, dims, 2, 6 * TF_DataTypeSize(TF_UINT64)); + TF_Tensor* b = + TF_AllocateTensor(TF_UINT64, nullptr, 0, TF_DataTypeSize(TF_UINT64)); + EXPECT_NE(a, nullptr); + EXPECT_NE(b, nullptr); + + EXPECT_EQ(6, TF_TensorElementCount(a)); + EXPECT_EQ(1, TF_TensorElementCount(b)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a)); + EXPECT_EQ(TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b)); + + int64_t new_dims[] = {3, 2}; + TF_Status* status = TF_NewStatus(); + TF_TensorBitcastFrom(a, TF_UINT64, b, new_dims, 2, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + + EXPECT_EQ(6, TF_TensorElementCount(a)); + EXPECT_EQ(6, TF_TensorElementCount(b)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b)); + + // Check that a write to one tensor shows up in the other. 
+ *(static_cast(TF_TensorData(a))) = 4; + EXPECT_EQ(4, *(static_cast(TF_TensorData(b)))); + *(static_cast(TF_TensorData(b))) = 6; + EXPECT_EQ(6, *(static_cast(TF_TensorData(a)))); + + TF_DeleteTensor(a); + TF_DeleteTensor(b); +} + REGISTER_OP("TestOpWithNoGradient") .Input("x: T") .Output("y: T") diff --git a/tensorflow/c/c_test.c b/tensorflow/c/c_test.c new file mode 100644 index 0000000000000000000000000000000000000000..7468122cd567270c8454f886e478be34c2c15cbf --- /dev/null +++ b/tensorflow/c/c_test.c @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/kernels.h" + +// A create function. This will never actually get called in this test, it's +// just nice to know that it compiles. +void* create(TF_OpKernelConstruction* ctx) { + TF_DataType type; + TF_Status* s = TF_NewStatus(); + TF_OpKernelConstruction_GetAttrType(ctx, "foobar", &type, s); + TF_DeleteStatus(s); + return NULL; +} + +// A compute function. This will never actually get called in this test, it's +// just nice to know that it compiles.
+void compute(void* kernel, TF_OpKernelContext* ctx) { + TF_Tensor* input = NULL; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + TF_DeleteTensor(input); + TF_DeleteStatus(s); +} + +// Exercises tensorflow's C API. +int main(int argc, char** argv) { + TF_InitMain(argv[0], &argc, &argv); + + struct TF_StringStream* s = TF_GetLocalTempDirectories(); + const char* path; + + if (!TF_StringStreamNext(s, &path)) { + fprintf(stderr, "TF_GetLocalTempDirectories returned no results\n"); + return 1; + } + + char file_name[100]; + struct timeval t; + if (gettimeofday(&t, NULL)) { + perror("gettimeofday failed"); + return 1; + } + // pid_t and time_t have no portable printf conversions; cast explicitly. + snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", (int)getpid(), + (long)t.tv_sec); + + size_t length = 2 + strlen(path) + strlen(file_name); + char* full_path = malloc(length); + if (full_path == NULL) { + perror("malloc failed"); + return 1; + } + snprintf(full_path, length, "%s/%s", path, file_name); + + TF_WritableFileHandle* h; + TF_Status* status = TF_NewStatus(); + TF_NewWritableFile(full_path, &h, status); + if (TF_GetCode(status) != TF_OK) { + fprintf(stderr, "TF_NewWritableFile failed: %s\n", TF_Message(status)); + free(full_path); + TF_DeleteStatus(status); + return 1; + } + fprintf(stderr, "wrote %s\n", full_path); + free(full_path); + TF_CloseWritableFile(h, status); + if (TF_GetCode(status) != TF_OK) { + fprintf(stderr, "TF_CloseWritableFile failed: %s\n", TF_Message(status)); + } + TF_StringStreamDone(s); + + TF_KernelBuilder* b = + TF_NewKernelBuilder("SomeOp", "SomeDevice", &create, &compute, NULL); + TF_RegisterKernelBuilder("someKernel", b, status); + + TF_DeleteStatus(status); + return 0; +} diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c34a84fcfee9b6ba9a7be86ae16e2856a2d343c7..282f0da302fac89c6fae9f8b5aa4b3c33ab93532 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,11 +3,19 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", - "tf_cuda_cc_test", - "tf_cc_test", "tf_copts", - "tfe_xla_copts", + "tf_cuda_cc_test", "tf_cuda_library", + "tfe_xla_copts", +) +load( +
"//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_device_tracer_test_flags", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", ) tf_cuda_library( @@ -62,6 +70,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core:gpu_runtime", ], ) @@ -101,6 +110,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/profiler/lib:profiler_session", ], ) @@ -148,6 +158,88 @@ tf_cuda_cc_test( ], ) +tf_cuda_library( + name = "c_api_experimental", + srcs = [ + "c_api_experimental.cc", + ], + hdrs = ["c_api_experimental.h"], + copts = tf_copts() + tfe_xla_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + ":c_api", + ":c_api_internal", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], + }) + select({ + 
"//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_device", + ], + "//conditions:default": [], + }) + [ + "@com_google_absl//absl/memory", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/core/profiler/rpc/client:capture_profile", + "//tensorflow/core:gpu_runtime", + ], +) + +tf_cuda_cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = [ + "c_api_experimental_test.cc", + ], + args = + ["--heap_check=local"] + tf_additional_device_tracer_test_flags(), + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":c_api_experimental", + ":c_api_test_util", + "//tensorflow/c:c_test_util", + "//tensorflow/cc/profiler", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "tape", hdrs = ["tape.h"], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 027d752f420238da867cb9d8c116640e1730caaa..45701c7fcf02d5e6ec464ae10d4d20f20ba1d9f0 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -356,6 +356,8 @@ 
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { if (h == nullptr) return; + VLOG(1) << "Deleting tensor handle " << h << " with internal handle " + << h->handle; if (h->handle) { h->handle->Unref(); } @@ -443,15 +445,15 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { return nullptr; } // TODO(agarwal): move this implementation inside TFE_TensorHandle. - tensorflow::Device* d = nullptr; - tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->handle->TensorAndDevice(&t, &d, &op_device); - if (!status->status.ok()) return nullptr; tensorflow::TensorHandle* h_cpu = nullptr; - if (!IsCPU(d)) { - status->status = h->handle->CopyToDevice( - h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + + if (h->handle->IsRemote()) { + status->status = EagerCopyToDevice( + h->handle, h->handle->Context(), + h->handle->Context()->HostCPU()->name().c_str(), &h_cpu); if (!status->status.ok()) { return nullptr; } @@ -460,6 +462,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { h_cpu->Unref(); return nullptr; } + } else { + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + + if (!IsCPU(d)) { + status->status = h->handle->CopyToDevice( + h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); + if (!status->status.ok()) { + return nullptr; + } + status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) { + h_cpu->Unref(); + return nullptr; + } + } } TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); if (h_cpu != nullptr) { @@ -696,6 +714,7 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { 
+ VLOG(1) << "Calling TFE_Execute() on op " << op; tensorflow::gtl::InlinedVector handle_retvals( *num_retvals); status->status = @@ -738,12 +757,18 @@ void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, status->status = ctx->context.AddFunctionDef(function->fdef); } +unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { + return ctx->context.FindFunctionDef(name) != nullptr; +} + void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(true); + ctx->context.SetShouldStoreGraphs(true); + ctx->context.SetShouldStoreStepStats(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(false); + ctx->context.SetShouldStoreGraphs(false); + ctx->context.SetShouldStoreStepStats(false); } } // extern "C" @@ -774,7 +799,7 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, if (!status->status.ok()) return; tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); - ctx->context.RunMetadataProto()->Clear(); + ctx->context.ClearRunMetadata(); } namespace { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index f80ae5a6d02d4d613c95cf8486e0fc0aeed3affc..044dfb7415b027b707af05a197fdb41fe1f6d2e5 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -170,23 +170,11 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); -// Returns the device of the operation that produced `h`. -// If `h` was produced by a copy, returns the destination device of -// the copy. Note that returned device name is not always the device -// holding the tensor handle's memory. If you want the latter, use -// TFE_TensorHandleBackingDeviceName. -// This function will block till the operation that produces `h` has completed. -// -// Device on which the kernel of the operation that produced `h` ran. 
-// -// If `h` was produced by a copy, returns the destination device of -// the copy. -// -// Note that returned device name is not always the device that owns the memory -// that backs the tensor handle. For the latter see -// TFE_TensorHandleBackingDeviceName. -// -// This function will block till the operation that produces `h` has completed. +// Returns the device of the operation that produced `h`. If `h` was produced by +// a copy, returns the destination device of the copy. Note that the returned +// device name is not always the device holding the tensor handle's memory. If +// you want the latter, use TFE_TensorHandleBackingDeviceName. This function +// will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); @@ -405,6 +393,10 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status); +// Checks whether a function is registered under `name`. +TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx, + const char* name); + // Enables tracing of RunMetadata on the ops executed from this context. 
TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx); diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 52b0824552855860dfb138f3ac9a5d3afa7dc965..ffcd5ace0b98597363abe63201bf6c328a03212f 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -83,7 +83,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( } } - if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (padded_shape.IsTuple()) { if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { // Currently, the only case of XlaTensor containing a tuple shape is to // represent 64 bit ints, doubles, and complex numbers (we don't support @@ -99,7 +99,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); const xla::Shape& shape1 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); - if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + if (shape0.IsTuple() || shape1.IsTuple()) { status->status = tensorflow::errors::InvalidArgument( "XlaTensors should not contain nested tuples. Shape: ", padded_shape.DebugString()); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff798593b5f2f77339b668668ff6dafb9f44a2b3 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/profiler/rpc/client/capture_profile.h" +#include "tensorflow/core/profiler/rpc/profiler_server.h" + +using tensorflow::string; + +void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { + op->operation.ConsumeInput(h->handle); +} + +TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx) { + return new TFE_Profiler(ctx); +} + +bool TFE_ProfilerIsOk(TFE_Profiler* profiler) { + return profiler->profiler->Status().ok(); +} + +void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; } + +void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, + TF_Buffer* buf, TF_Status* status) { + TFE_ContextAsyncWait(ctx, status); + if (!status->status.ok()) return; + string content; + status->status = profiler->profiler->SerializeToString(&content); + void* data = tensorflow::port::Malloc(content.length()); + content.copy(static_cast(data), content.length(), 0); + buf->data = data; + buf->length = content.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; +} + +TFE_ProfilerContext* TFE_NewProfilerContext() { + return new TFE_ProfilerContext; +} + +void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context, + TFE_Context* eager_context) { + profiler_context->profiler_context.eager_context = &eager_context->context; +} + 
+void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) { + delete profiler_context; +} + +void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) { + // Release child thread intentionally. The child thread can be terminated by + // terminating the main thread. + tensorflow::StartProfilerServer(&context->profiler_context, port).release(); +} + +void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { + ctx->context.SetShouldStoreGraphs(true); +} + +void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { + ctx->context.SetShouldStoreGraphs(false); +} + +bool TFE_ProfilerClientStartTracing(char* service_addr, char* logdir, + char* worker_list, bool include_dataset_ops, + int duration_ms, int num_tracing_attempts) { + tensorflow::Status s = + tensorflow::profiler::client::ValidateHostPortPair(service_addr); + if (!s.ok()) { + return false; + } + s = tensorflow::profiler::client::StartTracing( + service_addr, logdir, worker_list, include_dataset_ops, duration_ms, + num_tracing_attempts); + return s.ok(); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h new file mode 100644 index 0000000000000000000000000000000000000000..89523793d37b89ee49c4db844a85f019381ff730 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental.h @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, + TF_Status* status); + +typedef struct TFE_ProfilerContext TFE_ProfilerContext; + +// A profiler which will start profiling when creating the object and will stop +// when the object is destroyed. It will profile all operations run under the +// given TFE_Context. Multiple instances of it can be created, but at most one +// of them will profile for each TFE_Context. +// Thread-safety: TFE_Profiler is thread-safe. +typedef struct TFE_Profiler TFE_Profiler; + +TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx); +TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler); +TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler); + +// The output string is a binary string of tensorflow.tpu.Trace. User can write +// the string to file for offline analysis by tensorboard. +TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx, + TFE_Profiler* profiler, + TF_Buffer* buf, + TF_Status* status); + +// Return a new profiler context object. +TF_CAPI_EXPORT extern TFE_ProfilerContext* TFE_NewProfilerContext(void); + +// Set the eager context in the given TFE_ProfilerContext. +TF_CAPI_EXPORT extern void TFE_ProfilerContextSetEagerContext( + TFE_ProfilerContext* profiler_context, TFE_Context* eager_context); + +// Destroy a profiler context object. +TF_CAPI_EXPORT extern void TFE_DeleteProfilerContext( + TFE_ProfilerContext* profiler_context); + +// Start a profiler grpc server which listens to specified port. It will start +// the server on its own thread. It can be shut down by terminating tensorflow.
+// It can be used in both Eager mode and graph mode. Creating multiple profiler +// servers is allowed. The service is defined in +// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use +// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture a traceable +// file, following +// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. +TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_ProfilerContext* context, + int port); + +// Enables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx); + +// Disables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx); + +// Send a grpc request to profiler server (service_addr) to perform on-demand +// profiling and save the result into logdir which can be visualized by +// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set +// include_dataset_ops to false to profile longer traces. It will block the +// caller thread until it receives the tracing result. +// This API is designed for TensorBoard; end users should use +// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following +// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
+TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing( + char* service_addr, char* logdir, char* worker_list, + bool include_dataset_ops, int duration_ms, int num_tracing_attempts); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d85048caa7c7f727271352883cb834a2575bd251 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental.h" + +#include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/cc/profiler/profiler.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/profiler/trace_events.pb.h" + +using tensorflow::string; + +namespace tensorflow { +namespace { + +static bool HasSubstr(absl::string_view base, absl::string_view substr) { + bool ok = str_util::StrContains(base, substr); + EXPECT_TRUE(ok) << base << ", expected substring " << substr; + return ok; +} + +void ExecuteWithProfiling(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext(); + TFE_ProfilerContextSetEagerContext(profiler_context, ctx); + TFE_Profiler* profiler = TFE_NewProfiler(profiler_context); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_DeleteProfilerContext(profiler_context); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + + // Run op on GPU if it is present. 
+ string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + const char* device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr); + } + + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TF_Buffer* profiler_result = TF_NewBuffer(); + TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status); + TFE_DeleteProfiler(profiler); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + profiler::Trace profile_proto; + EXPECT_TRUE(profile_proto.ParseFromString( + {reinterpret_cast(profiler_result->data), + profiler_result->length})); + string profile_proto_str = profile_proto.DebugString(); + if (!gpu_device_name.empty()) { + EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0")); + // device name with "stream:all" is collected by Device Tracer. 
+ EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all")); + } + EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0")); + TF_DeleteBuffer(profiler_result); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} +TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); } +TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); } + +TEST(CAPI, MultipleProfilerSession) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(false)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext(); + TFE_ProfilerContextSetEagerContext(profiler_context, ctx); + + TFE_Profiler* profiler1 = TFE_NewProfiler(profiler_context); + EXPECT_TRUE(TFE_ProfilerIsOk(profiler1)); + + TFE_Profiler* profiler2 = TFE_NewProfiler(profiler_context); + EXPECT_FALSE(TFE_ProfilerIsOk(profiler2)); + + TFE_DeleteProfiler(profiler1); + TFE_DeleteProfiler(profiler2); + TFE_DeleteProfilerContext(profiler_context); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82..a563e4b8f50f2a90497736f4cb9ca234400bfa04 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -52,6 +52,7 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/public/version.h" struct TFE_ContextOptions { @@ -82,6 +83,12 @@ struct TFE_TensorHandle { TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; + + // Create a symbolic tensor. + TFE_TensorHandle(TF_Output t, TF_DataType dtype) + : handle(new tensorflow::TensorHandle( + tensorflow::OutputGraphNode{t.oper, t.index}, + static_cast(dtype))) {} }; struct TFE_TensorDebugInfo { @@ -100,6 +107,18 @@ struct TFE_Op { tensorflow::EagerOperation operation; }; +struct TFE_ProfilerContext { + tensorflow::ProfilerContext profiler_context; +}; + +struct TFE_Profiler { + TFE_Profiler(TFE_ProfilerContext* ctx) { + profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context); + } + + std::unique_ptr profiler; +}; + namespace tensorflow { // Set an AttrValue on the op. Doesn't handle the list types. void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, @@ -107,4 +126,24 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const char* attr_name, TF_Status* status); } // namespace tensorflow +struct TFE_TraceContext { + TF_Graph* const graph; + + unsigned int node_counter = 0; + // Each tensor handle will have its ref count incremented when it's added as a + // map key, and decremented when this object is destroyed. 
+ std::map input_tensor_map; + std::vector>* input_tensors = + nullptr; + + TFE_TraceContext(TF_Graph* graph) : graph(graph) {} + + ~TFE_TraceContext() { + delete input_tensors; + for (auto input : input_tensor_map) { + input.first->Unref(); + } + } +}; + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 6b39b79ee82f9c7baaf856e573a42b7da65691e5..3d1ca4fb4b561a03ea9d879b1876fb1fd08a3139 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -175,13 +175,8 @@ void TestRemoteExecute(bool async) { TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - auto* retval_task0 = TFE_TensorHandleCopyToDevice( - retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteTensorHandle(retval_task0); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc index 07b9e8b940c55caf62ae0b81b884bf313d335459..1c35ff9001d0ee1ab0fbae9e1bcc07116fab1065 100644 --- a/tensorflow/c/env.cc +++ b/tensorflow/c/env.cc @@ -159,3 +159,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) { TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) { return ::tensorflow::Env::Default()->NowSeconds(); } + +void TF_DefaultThreadOptions(TF_ThreadOptions* options) { + options->stack_size = 0; + options->guard_size = 0; + options->numa_node = -1; +} + +TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, void (*work_func)(void*), + void* param) { + ::tensorflow::ThreadOptions cc_options; + 
cc_options.stack_size = options->stack_size; + cc_options.guard_size = options->guard_size; + cc_options.numa_node = options->numa_node; + return reinterpret_cast(::tensorflow::Env::Default()->StartThread( + cc_options, thread_name, [=]() { (*work_func)(param); })); +} + +void TF_JoinThread(TF_Thread* thread) { + // ::tensorflow::Thread joins on destruction + delete reinterpret_cast<::tensorflow::Thread*>(thread); +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h index 9d27c5da37735042c7476b591e57486dbde33152..73078fcbbc5ae4c042f4a992655072a838e42915 100644 --- a/tensorflow/c/env.h +++ b/tensorflow/c/env.h @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #ifndef TENSORFLOW_C_ENV_H_ #define TENSORFLOW_C_ENV_H_ @@ -21,13 +25,14 @@ limitations under the License. // -------------------------------------------------------------------------- // C API for tensorflow::Env. -struct TF_WritableFileHandle; -struct TF_StringStream; - #ifdef __cplusplus extern "C" { #endif +typedef struct TF_WritableFileHandle TF_WritableFileHandle; +typedef struct TF_StringStream TF_StringStream; +typedef struct TF_Thread TF_Thread; + typedef struct TF_FileStatistics { // The length of the file in bytes. int64_t length; @@ -37,6 +42,20 @@ typedef struct TF_FileStatistics { bool is_directory; } TF_FileStatistics; +typedef struct TF_ThreadOptions { + // Thread stack size to use (in bytes), zero implies that the system default + // will be used. + size_t stack_size; + + // Guard area size to use near thread stacks to use (in bytes), zero implies + // that the system default will be used. + size_t guard_size; + + // The NUMA node to use, -1 implies that there should be no NUMA affinity for + // this thread. + int numa_node; +} TF_ThreadOptions; + // Creates the specified directory. 
Typical status code are: // * TF_OK - successfully created the directory // * TF_ALREADY_EXISTS - directory already exists @@ -150,6 +169,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); // Returns the number of seconds since the Unix epoch. TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); +// Populates a TF_ThreadOptions struct with system-default values. +TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options); + +// Returns a new thread that is running work_func and is identified +// (for debugging/performance-analysis) by thread_name. +// +// The given param (which may be null) is passed to work_func when the thread +// starts. In this way, data may be passed from the thread back to the caller. +// +// Caller takes ownership of the result and must call TF_JoinThread on it +// eventually. +TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, + void (*work_func)(void*), + void* param); + +// Waits for the given thread to finish execution, then deletes it. +TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); + #ifdef __cplusplus } #endif diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc index e2206c6befd2167346c64032940d6e8c631e4a3e..687ad024137352662759ec1f43df87e89faca353 100644 --- a/tensorflow/c/env_test.cc +++ b/tensorflow/c/env_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -98,3 +99,29 @@ TEST(TestEnv, TestTimeFunctions) { ASSERT_GE(TF_NowMicros(), 946684800 * 1e6); ASSERT_GE(TF_NowNanos(), 946684800 * 1e9); } + +namespace { + +struct SomeThreadData { + ::tensorflow::mutex mu; + bool did_work = false; +}; + +void SomeThreadFunc(void* data) { + auto* real_data = static_cast(data); + ::tensorflow::mutex_lock l(real_data->mu); + real_data->did_work = true; +} + +} // namespace + +TEST(TestEnv, TestThreads) { + TF_ThreadOptions options; + TF_DefaultThreadOptions(&options); + SomeThreadData data; + TF_Thread* thread = + TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data); + TF_JoinThread(thread); + ::tensorflow::mutex_lock l(data.mu); + ASSERT_TRUE(data.did_work); +} diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 2a4eaecb6cf2740a522b1e849d1306ebde6c4577..71181ae430ab64106e2a75937bd54fbf2efc61ac 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -48,9 +48,10 @@ TF_KernelBuilder* TF_NewKernelBuilder( } void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { - DCHECK_NE(builder, nullptr); - delete builder->cc_builder; - delete builder; + if (builder != nullptr) { + delete builder->cc_builder; + delete builder; + } } namespace tensorflow { @@ -158,3 +159,41 @@ void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, cc_ctx->set_output(i, cc_tensor); } } + +void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status)); + cc_ctx->CtxFailure(s); +} + +void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); 
+ ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status)); + cc_ctx->CtxFailure(s); +} + +#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ + void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \ + const char* attr_name, \ + c_type* val, TF_Status* status) { \ + TF_SetStatus(status, TF_OK, ""); \ + cc_type v; \ + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \ + ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); \ + ::tensorflow::Set_TF_Status_from_Status(status, s); \ + if (s.ok()) { \ + *val = static_cast(v); \ + } \ + } + +DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) + +TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return static_cast(cc_ctx->expected_output_dtype(i)); +} + +int64_t TF_StepId(TF_OpKernelContext* ctx) { + return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id(); +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 1a91aa184f11ac8e45b38a1d106c7b445747a7c1..c47bfa8aa3a721d422a0a1536b924f3e53793193 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -35,9 +35,9 @@ extern "C" { // `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided // kernels when necessary. -struct TF_KernelBuilder; -struct TF_OpKernelConstruction; -struct TF_OpKernelContext; +typedef struct TF_KernelBuilder TF_KernelBuilder; +typedef struct TF_OpKernelConstruction TF_OpKernelConstruction; +typedef struct TF_OpKernelContext TF_OpKernelContext; // Allocates a new kernel builder and returns a pointer to it. // @@ -111,6 +111,32 @@ TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, TF_Status* status); +// Notifies the given OpKernelConstruction that kernel construction has failed. 
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure( + TF_OpKernelConstruction* ctx, TF_Status* status); + +// Notifies the given OpKernelContext that the kernel's compute function has +// failed. +TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, + TF_Status* status); + +// Returns the expected output data type of the ith output. If i < 0 or +// i >= TF_NumOutputs(ctx), the program aborts. +TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( + TF_OpKernelContext* ctx, int i); + +// Returns the step ID of the given context. +TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); + +// Interprets the named kernel construction attribute as a TF_DataType and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// TF_DataType, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index e659ee3c3d258a626ccf03a782ec031b5a703a48..608887722f7bca44c884a3426d5e378e9387a530 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/kernels.h" #include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/op.h" @@ -35,12 +36,24 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { struct MyCustomKernel* s = new struct MyCustomKernel; s->created = true; s->compute_called = false; + + // Exercise attribute reads. 
+ TF_DataType type; + TF_Status* status = TF_NewStatus(); + TF_OpKernelConstruction_GetAttrType(ctx, "SomeDataTypeAttr", &type, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + EXPECT_EQ(TF_FLOAT, type); + TF_DeleteStatus(status); + return s; } static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { struct MyCustomKernel* s = static_cast(kernel); s->compute_called = true; + if (ctx != nullptr) { + EXPECT_EQ(43, TF_StepId(ctx)); + } } static void MyDeleteFunc(void* kernel) { @@ -61,6 +74,11 @@ static std::unique_ptr GetFakeKernel(const char* device_name, def.set_device(device_name); def.add_input("input1"); def.add_input("input2"); + + AttrValue v; + v.set_type(DataType::DT_FLOAT); + (*def.mutable_attr())["SomeDataTypeAttr"] = v; + return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1, status); } @@ -75,7 +93,8 @@ TEST(TestKernel, TestRegisterKernelBuilder) { REGISTER_OP(op_name) .Input("input1: double") .Input("input2: uint8") - .Output("output1: uint8"); + .Output("output1: uint8") + .Attr("SomeDataTypeAttr: type"); TF_KernelBuilder* builder = TF_NewKernelBuilder( op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); @@ -126,7 +145,8 @@ TEST(TestKernel, TestInputAndOutputCount) { REGISTER_OP(op_name) .Input("input1: double") .Input("input2: uint8") - .Output("output1: uint8"); + .Output("output1: uint8") + .Attr("SomeDataTypeAttr: type"); static int num_inputs = 0; static int num_outputs = 0; @@ -155,6 +175,8 @@ TEST(TestKernel, TestInputAndOutputCount) { TF_SetOutput(ctx, 24, input, s); EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + EXPECT_EQ(TF_UINT8, TF_ExpectedOutputDataType(ctx, 0)); + TF_DeleteStatus(s); if (input != nullptr) { TF_DeleteTensor(input); @@ -175,6 +197,7 @@ TEST(TestKernel, TestInputAndOutputCount) { OpKernelContext::Params p; DummyDevice dummy_device(nullptr, false); p.device = &dummy_device; + p.step_id = 43; Tensor t(tensorflow::uint8(123)); @@ -200,4 +223,8 @@ TEST(TestKernel, 
TestInputAndOutputCount) { } } +TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) { + TF_DeleteKernelBuilder(nullptr); +} + } // namespace tensorflow diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a09becc49b10d2c58f98fbcc11df5190f794c1d4..4c4d587fce04d101b3cc8faebcc3ba04f2f1d0cf 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -150,6 +150,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ], ) @@ -586,6 +587,25 @@ tf_gen_op_wrappers_cc( pkg = "//tensorflow/core", ) +tf_gen_op_wrappers_cc( + name = "tpu_ops", + include_internal_ops = 1, + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + cc_library_with_android_deps( name = "cc_op_gen_main", srcs = [ diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 39593370d1c243e84dc5b6091724d1d404c102b0..43a33cbea6e1e4a50f61cc7d6d8d70cac6a603d2 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -321,6 +321,7 @@ std::pair AttrTypeName(StringPiece attr_type) { {"tensor", {"TensorProto", true}}, {"list(tensor)", {"gtl::ArraySlice", true}}, {"func", {"NameAttrList", true}}, + {"list(func)", {"gtl::ArraySlice", true}}, }; auto entry = attr_type_map->find(attr_type); diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index affd90b1bcc7cb4a8b3ffed6aeeb4bd480f5e314..a7e645e8b556f14f0c7a51d2eba6ab1e2256b837 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -96,7 +96,7 @@ class SymbolicGradientBuilder { // Used to identify 
nodes at which to stop backprop. std::unordered_set GetStopBackpropNodes( const std::vector& reachable_nodes, - std::unordered_set output_nodes); + const std::unordered_set& output_nodes); const Scope& scope_; const ops::GradOpRegistry* registry_; @@ -167,7 +167,6 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, std::vector SymbolicGradientBuilder::GetReachableNodes() { std::vector reachable_nodes(scope_.graph()->num_node_ids(), false); std::deque queue; - std::vector visited(scope_.graph()->num_node_ids(), false); for (const Output& out : outputs_) { if (!reachable_nodes[out.node()->id()]) { queue.push_back(out.node()); @@ -180,10 +179,10 @@ std::vector SymbolicGradientBuilder::GetReachableNodes() { queue.pop_front(); for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) continue; - if (visited[e->src()->id()]) continue; - queue.push_back(e->src()); - reachable_nodes[e->src()->id()] = true; - visited[e->src()->id()] = true; + if (!reachable_nodes[e->src()->id()]) { + queue.push_back(e->src()); + reachable_nodes[e->src()->id()] = true; + } } } return reachable_nodes; @@ -191,7 +190,7 @@ std::vector SymbolicGradientBuilder::GetReachableNodes() { std::unordered_set SymbolicGradientBuilder::GetStopBackpropNodes( const std::vector& reachable_nodes, - std::unordered_set output_nodes) { + const std::unordered_set& output_nodes) { // Output nodes that get transitively consumed by other `outputs_` are stored // in `internal_outputs`. 
std::unordered_set internal_outputs; diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 882709e1e2817431a32c453fe0f35f2b2e6c69b0..05c287bdc62cdb8be7208ce3975f280aaa816766 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -69,6 +69,23 @@ Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper); +Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + string kernel_type; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type)); + grad_outputs->push_back(internal::ScaleAndTranslateGrad( + scope, grad_inputs[0], op.input(0), op.input(2), op.input(3), + internal::ScaleAndTranslateGrad::KernelType(kernel_type))); + + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc index 2e55c7561b030c50bd67bd53fd0d55710085c5d2..1d150226538093467e092e02f38090a327f9c9b6 100644 --- a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -30,6 +30,7 @@ using ops::Const; using ops::ResizeBicubic; using ops::ResizeBilinear; using ops::ResizeNearestNeighbor; +using ops::ScaleAndTranslate; class ImageGradTest : public ::testing::Test { protected: @@ -153,5 +154,45 @@ TEST_F(ImageGradTest, TestBicubic) { TestResize(RESIZE_BICUBIC); } +class ScaleAndTranslateGradTest : public ::testing::Test { + protected: + ScaleAndTranslateGradTest() : scope_(Scope::NewRootScope()) {} + + template + Tensor MakeData(const TensorShape& data_shape) { 
+ DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x, + Output* y) { + *x = Const(scope_, x_data); + *y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f}); + TF_ASSERT_OK(scope_.status()); + } + + template + void TestResize() { + TensorShape x_shape({1, 2, 3, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(x_data, {4, 6}, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + Scope scope_; +}; + +TEST_F(ScaleAndTranslateGradTest, Works) { TestResize(); } + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD index cf65fe1ab99b49207a64e86310178141b30d07d7..e9838d9aba6554b40082187057851e9c896f8352 100644 --- a/tensorflow/cc/profiler/BUILD +++ b/tensorflow/cc/profiler/BUILD @@ -10,7 +10,7 @@ tf_cuda_cc_test( name = "profiler_test", srcs = ["profiler_test.cc"], tags = [ - "noguitar", # b/77649654 + "nogpu", # b/77649654 ], deps = [ ":profiler", diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 52345a376cc29ee47ccb9888c9bb26292468b5a9..dedd55f16afb879ea966dc89d14d88ee15d9e83e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -81,6 +81,7 @@ cc_library( ] + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", ]) + if_android([ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 85d3dd01fa51b3c3ba6fcbf5faac03f1ff5630e2..66260fcf4a9b24f78d45010c6e86d4ee398b6d3d 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ 
b/tensorflow/cc/saved_model/loader.cc @@ -21,11 +21,11 @@ limitations under the License. #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" -#include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -42,9 +42,28 @@ auto* load_latency = monitoring::Counter<1>::New( "/tensorflow/cc/saved_model/load_latency", "Latency in microseconds for SavedModels that were successfully loaded.", "model_path"); +auto* load_latency_by_stage = monitoring::Sampler<2>::New( + { + "/tensorflow/cc/saved_model/load_latency_by_stage", // metric name + "Distribution of wall time spent (in microseconds) in each stage " + "(restore graph from disk, run init graph op, etc) when loading the " + "model", + "model_path", + "stage", + }, + // Scale of 10, power of 1.8 with bucket count 33 (~20 minutes). + monitoring::Buckets::Exponential(10, 1.8, 33)); + constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; +uint64 GetLatencyMicroseconds(const uint64 start_microseconds) { + const uint64 end_microseconds = Env::Default()->NowMicros(); + // Avoid clock skew. 
+ if (end_microseconds < start_microseconds) return 0; + return end_microseconds - start_microseconds; +} + Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { @@ -242,6 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundle* const bundle) { + const uint64 read_start_microseconds = Env::Default()->NowMicros(); TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, &bundle->meta_graph_def)); @@ -256,12 +276,23 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(), asset_file_defs, bundle->session.get())); + // Record walltime spent in restoring graph from disk, but postpone metric + // increments until graph init finishes. + const uint64 restore_graph_walltime = + GetLatencyMicroseconds(read_start_microseconds); + + const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); string init_op_name; TF_RETURN_IF_ERROR( GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, asset_file_defs, bundle->session.get(), init_op_name)); + load_latency_by_stage->GetCell(export_dir, "restore_graph") + ->Add(restore_graph_walltime); + // Record wall time spent in init op. 
+ load_latency_by_stage->GetCell(export_dir, "init_graph") + ->Add(GetLatencyMicroseconds(graph_init_start_microseconds)); return Status::OK(); } @@ -275,16 +306,10 @@ Status LoadSavedModel(const SessionOptions& session_options, const uint64 start_microseconds = Env::Default()->NowMicros(); const Status status = LoadSavedModelInternal(session_options, run_options, export_dir, tags, bundle); - const uint64 load_latency_microsecs = [&]() -> uint64 { - const uint64 end_microseconds = Env::Default()->NowMicros(); - // Avoid clock skew. - if (end_microseconds < start_microseconds) return 0; - return end_microseconds - start_microseconds; - }(); auto log_and_count = [&](const string& status_str) { LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") << " }; Status: " << status_str << ". Took " - << load_latency_microsecs << " microseconds."; + << GetLatencyMicroseconds(start_microseconds) << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; if (status.ok()) { @@ -292,7 +317,8 @@ Status LoadSavedModel(const SessionOptions& session_options, } else { log_and_count(kLoadAttemptFail); } - load_latency->GetCell(export_dir)->IncrementBy(load_latency_microsecs); + load_latency->GetCell(export_dir) + ->IncrementBy(GetLatencyMicroseconds(start_microseconds)); return status; } diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 23e9dc40d23899b9cef168c9128b6d8ed1be3ee9..eeb910178902ca883ed211379ba3f188c139f92e 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -124,7 +124,9 @@ Status GetVariableNameToTensorMap( return Status::OK(); } std::vector variable_names; + variable_names.reserve(variable_names_set.size()); std::vector tensor_names; + tensor_names.reserve(variable_names_set.size()); for (const string& node_name : variable_names_set) { variable_names.push_back(node_name); NodeDef* node_def = 
name_to_node_map.at(node_name); diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2cf68c9cd8396987899b4f34f21b994b4722ead4 --- /dev/null +++ b/tensorflow/compat_template.__init__.py @@ -0,0 +1,56 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function + +import os as _os +import sys as _sys + +# pylint: disable=g-bad-import-order + +# API IMPORTS PLACEHOLDER + +from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg=( + "Limited tf.compat.v2.summary API due to missing TensorBoard " + "installation")) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v2.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v2.keras')) + +# We would like the following to work for fully 
enabling 2.0 in a 1.0 install: +# +# import tensorflow.compat.v2 as tf +# tf.enable_v2_behavior() +# +# This make this one symbol available directly. +from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top + +# Add module aliases +_current_module = _sys.modules[__name__] +if hasattr(_current_module, 'keras'): + losses = keras.losses + metrics = keras.metrics + optimizers = keras.optimizers diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index 7df80ec01245a7fe820c79d5879458c4cd0a93cb..9549a71c41a0ba2aac58abd8cfb182aa4eaf3b4f 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -23,12 +23,15 @@ import os as _os # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) - -# API IMPORTS PLACEHOLDER - + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v1.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v1.keras')) from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 16151e77737429f4fbf690fc34b12a70bacebdc4..af016bf80e7a10d8729a1eb385466af48b5810cd 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -30,6 +30,7 @@ cc_library( "flags.h", ], deps = [ + ":aot_only_var_handle_op", ":embedded_protocol_buffers", "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:cpu_function_runtime", @@ -71,6 +72,7 @@ tf_cc_test( 
":tfcompile_lib", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", @@ -205,6 +207,15 @@ cc_library( ], ) +cc_library( + name = "aot_only_var_handle_op", + srcs = ["aot_only_var_handle_op.cc"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + ], + alwayslink = 1, +) + tf_cc_test( name = "benchmark_test", srcs = ["benchmark_test.cc"], diff --git a/tensorflow/compiler/aot/aot_only_var_handle_op.cc b/tensorflow/compiler/aot/aot_only_var_handle_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ce36a979f424610a5aa952afa8db2245ed971a9 --- /dev/null +++ b/tensorflow/compiler/aot/aot_only_var_handle_op.cc @@ -0,0 +1,56 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +// Implementation of varhandle that binds a VarHandleOp to an XlaResource of the +// same name. It is not safe to use this op in a JIT context. 
+class XlaAotOnlyVarHandleOp : public XlaOpKernel { + public: + explicit XlaAotOnlyVarHandleOp(OpKernelConstruction* c); + void Compile(XlaOpKernelContext* context) override; + + private: + string name_; +}; + +XlaAotOnlyVarHandleOp::XlaAotOnlyVarHandleOp(OpKernelConstruction* c) + : XlaOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("shared_name", &name_)); +} + +void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) { + // Look for a resource of the same name. TF also keys that on the container + // and type attributes, but that doesn't seem necessary. + for (const auto& resource : context->xla_context()->resources()) { + if (resource->kind() == XlaResource::kVariable && + resource->name() == name_) { + context->SetResourceOutput(0, resource.get()); + return; + } + } + context->SetStatus( + errors::InvalidArgument("Variable: ", name_, " not configured")); +} +} // namespace + +REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ab1c1be344e2257721507543bc7647d4ff4becb2..da0598736a7d6b7f55458d76ca30fa6ad46a74f9 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -129,7 +129,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); std::vector dim_vars; string dim_sizes, indices; - if (xla::ShapeUtil::Rank(shape) == 0 || + if (shape.rank() == 0 || (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { dim_sizes = "[1]"; indices = "[0]"; @@ -168,12 +168,12 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (config.feed_size() != num_args) { - return errors::InvalidArgument("mismatch between feed_size(", - config.feed_size(), ") and num_args(", - num_args, ")"); + if 
(config.feed_size() + config.variable_size() != num_args) { + return errors::InvalidArgument( + "mismatch between feed_size(", config.feed_size(), ")+variable_size(", + config.variable_size(), ") and num_args(", num_args, ")"); } - for (int i = 0; i < num_args; ++i) { + for (int i = 0; i < config.feed_size(); ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR( AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); @@ -212,12 +212,14 @@ Status GenResultMethods(const tf2xla::Config& config, // tuple result, and we rely on this to simplify code generation. return errors::Internal("codegen requires the XLA result to be a tuple"); } - if (config.fetch_size() != ps.result().tuple_shapes_size()) { + size_t num_results = ps.result().tuple_shapes_size(); + if (config.fetch_size() + config.variable_size() != num_results) { return errors::InvalidArgument("mismatch between fetch_size(", - config.feed_size(), ") and tuple_size(", + config.fetch_size(), ")+variable_size(", + config.variable_size(), ") and tuple_size(", ps.result().tuple_shapes_size(), ")"); } - for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { + for (int i = 0; i < config.fetch_size(); ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR(AddRewritesForShape( i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); @@ -245,6 +247,51 @@ Status GenResultMethods(const tf2xla::Config& config, return Status::OK(); } +// Generate methods for variables. 
+Status GenVariableMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, string* methods) { + size_t num_args = ps.parameters_size(); + for (int i = config.feed_size(); i < num_args; ++i) { + std::vector> rewrites; + TF_RETURN_IF_ERROR( + AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); + const string code = R"( + void set_var_{{NAME}}_input_data({{TYPE}}* data) { + set_arg_data({{I}}, data); + } +)"; + const tf2xla::Variable& var = config.variable(i - config.feed_size()); + *methods += RewriteWithName( + var.name().empty() ? var.node_name() : var.name(), code, rewrites); + } + size_t num_results = ps.result().tuple_shapes_size(); + for (int i = config.fetch_size(); i < num_results; ++i) { + std::vector> rewrites; + TF_RETURN_IF_ERROR(AddRewritesForShape( + i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); + string code = R"( + {{TYPE}}* var_{{NAME}}_result_data() { + return static_cast<{{TYPE}}*>(result_data({{I}})); + } + {{TYPE}}& var_{{NAME}}_result({{DIM_VARS}}) { + return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( + result_data({{I}}))){{INDICES}}; + } + const {{TYPE}}* var_{{NAME}}_result_data() const { + return static_cast(result_data({{I}})); + } + const {{TYPE}}& var_{{NAME}}_result({{DIM_VARS}}) const { + return (*static_cast( + result_data({{I}}))){{INDICES}}; + } +)"; + const tf2xla::Variable& var = config.variable(i - config.fetch_size()); + *methods += RewriteWithName( + var.name().empty() ? var.node_name() : var.name(), code, rewrites); + } + return Status::OK(); +} + // Generates code implementing {Arg,Result}Names(), where T is one of // tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string // literal in the array, with nullptr terminating the array. 
@@ -291,6 +338,14 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name")); } } + for (const tf2xla::Variable& variable : config.variable()) { + if (!variable.name().empty()) { + TF_RETURN_IF_ERROR(ValidateCppIdent(variable.name(), "variable name")); + } else { + TF_RETURN_IF_ERROR( + ValidateCppIdent(variable.node_name(), "variable name")); + } + } return Status::OK(); } @@ -339,9 +394,10 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, std::vector buffer_infos_for_temps = ExtractTempBufferInfos(buffer_infos); const xla::ProgramShapeProto& ps = compile_result.program_shape; - string methods_arg, methods_result; + string methods_arg, methods_result, methods_variable; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); + TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable)); const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( buffer_infos_for_args.data(), buffer_infos_for_args.size(), /*allocate_entry_params=*/true); @@ -384,8 +440,9 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // calling HloProfilePrinter::profile_counters_size. const string assign_profile_counters_size = opts.gen_hlo_profile_printer_data - ? "data->set_profile_counters_size(" - "data->hlo_profile_printer_data()->profile_counters_size());" + ? 
"set_static_data_profile_counters_size(data, " + "get_static_data_hlo_profile_printer_data(data)->" + "profile_counters_size());" : ""; // Use a poor-man's text templating mechanism; first populate the full header @@ -449,7 +506,7 @@ extern "C" void {{ENTRY}}( // arg bytes aligned: {{ARG_BYTES_ALIGNED}} // temp bytes total: {{TEMP_BYTES_TOTAL}} // temp bytes aligned: {{TEMP_BYTES_ALIGNED}} -class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { +class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; @@ -464,16 +521,17 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->set_raw_function({{ENTRY}}); - data->set_buffer_infos(BufferInfos()); - data->set_num_buffers(kNumBuffers); - data->set_arg_index_table(ArgIndexToBufferIndex()); - data->set_num_args(kNumArgs); - data->set_result_index(kResultIndex); - data->set_arg_names(StaticArgNames()); - data->set_result_names(StaticResultNames()); - data->set_program_shape(StaticProgramShape()); - data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + set_static_data_raw_function(data, {{ENTRY}}); + set_static_data_buffer_infos(data, BufferInfos()); + set_static_data_num_buffers(data, kNumBuffers); + set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); + set_static_data_num_args(data, kNumArgs); + set_static_data_result_index(data, kResultIndex); + set_static_data_arg_names(data, StaticArgNames()); + set_static_data_result_names(data, StaticResultNames()); + set_static_data_program_shape(data, StaticProgramShape()); + set_static_data_hlo_profile_printer_data( + data, StaticHloProfilePrinterData()); {{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); @@ -521,6 +579,21 @@ class {{CLASS}} : public 
tensorflow::XlaCompiledCpuFunction { // buffers are managed internally, and may change after each call to Run. {{METHODS_RESULT}} + // Methods for managing variable buffers. Buffers are in row-major order. The + // input and output buffers may or may not be identical. + // + // void set_var_X_data(T* data) + // Sets the buffer for variable X. + // + // T* var_X_data() + // Returns the buffer of type T for variable X. + // + // T& var_X(...dim indices...) + // Returns a reference to the value of type T for variable X, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. +{{METHODS_VARIABLE}} + private: // Number of buffers for the compiled computation. static constexpr size_t kNumBuffers = {{NUM_BUFFERS}}; @@ -587,6 +660,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { include_hlo_profile_printer_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, + {"{{METHODS_VARIABLE}}\n", methods_variable}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index c1788ca32a1d099284eeb870f9513891051fd29e..5580e55b691bd10698b63d86bc0194b25da743b9 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -22,6 +22,8 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" @@ -172,6 +174,15 @@ TEST(CodegenTest, Golden) { tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); + tf2xla::Variable* variable = config.add_variable(); + variable->set_node_name("myvar"); + variable->mutable_shape()->add_dim()->set_size(1); + variable->set_type(DT_FLOAT); + tf2xla::Variable* variable2 = config.add_variable(); + variable2->set_node_name("my/var"); + variable2->set_name("myvar2"); + variable2->mutable_shape()->add_dim()->set_size(5); + variable2->set_type(DT_INT32); CompileResult compile_result; compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( {}, @@ -186,9 +197,14 @@ TEST(CodegenTest, Golden) { { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), + xla::ShapeUtil::MakeShape(xla::F32, {1}), + xla::ShapeUtil::MakeShape(xla::S32, {5}), }, - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})) + xla::ShapeUtil::MakeTupleShape({ + xla::ShapeUtil::MakeShape(xla::U32, {5, 6}), + xla::ShapeUtil::MakeShape(xla::F32, {1}), + xla::ShapeUtil::MakeShape(xla::S32, {5}), + })) .ToProto(); compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 968afad65ed6d4b5510687df484b7ce6743f6a85..b5f33d690d492489e9090786cd341e035ae7ca15 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -52,14 +52,14 @@ namespace bar { // is guaranteed that no thread may call a non-const 
method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4]) -> (u32[5,6]) +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) // // Memory stats: // arg bytes total: 104 // arg bytes aligned: 192 // temp bytes total: 126 // temp bytes aligned: 320 -class MyClass : public tensorflow::XlaCompiledCpuFunction { +class MyClass final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = 2; @@ -74,16 +74,17 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->set_raw_function(entry_point); - data->set_buffer_infos(BufferInfos()); - data->set_num_buffers(kNumBuffers); - data->set_arg_index_table(ArgIndexToBufferIndex()); - data->set_num_args(kNumArgs); - data->set_result_index(kResultIndex); - data->set_arg_names(StaticArgNames()); - data->set_result_names(StaticResultNames()); - data->set_program_shape(StaticProgramShape()); - data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + set_static_data_raw_function(data, entry_point); + set_static_data_buffer_infos(data, BufferInfos()); + set_static_data_num_buffers(data, kNumBuffers); + set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); + set_static_data_num_args(data, kNumArgs); + set_static_data_result_index(data, kResultIndex); + set_static_data_arg_names(data, StaticArgNames()); + set_static_data_result_names(data, StaticResultNames()); + set_static_data_program_shape(data, StaticProgramShape()); + set_static_data_hlo_profile_printer_data( + data, StaticHloProfilePrinterData()); return data; }(); @@ -213,6 +214,58 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { result_data(0)))[dim0][dim1]; } + // Methods for managing variable 
buffers. Buffers are in row-major order. The + // input and output buffers may or may not be identical. + // + // void set_var_X_data(T* data) + // Sets the buffer for variable X. + // + // T* var_X_data() + // Returns the buffer of type T for variable X. + // + // T& var_X(...dim indices...) + // Returns a reference to the value of type T for variable X, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. + + void set_var_myvar_input_data(float* data) { + set_arg_data(2, data); + } + + void set_var_myvar2_input_data(tensorflow::int32* data) { + set_arg_data(3, data); + } + + float* var_myvar_result_data() { + return static_cast(result_data(1)); + } + float& var_myvar_result() { + return (*static_cast( + result_data(1)))[0]; + } + const float* var_myvar_result_data() const { + return static_cast(result_data(1)); + } + const float& var_myvar_result() const { + return (*static_cast( + result_data(1)))[0]; + } + + tensorflow::int32* var_myvar2_result_data() { + return static_cast(result_data(2)); + } + tensorflow::int32& var_myvar2_result(size_t dim0) { + return (*static_cast( + result_data(2)))[dim0]; + } + const tensorflow::int32* var_myvar2_result_data() const { + return static_cast(result_data(2)); + } + const tensorflow::int32& var_myvar2_result(size_t dim0) const { + return (*static_cast( + result_data(2)))[dim0]; + } + private: // Number of buffers for the compiled computation. 
static constexpr size_t kNumBuffers = 6; @@ -256,7 +309,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShapeProto* StaticProgramShape() { static const xla::ProgramShapeProto* kShape = []() { xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52); + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 132); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index ce8e5ec8c96a2c3696f14b8eea206d648182ecb5..2884597abcf29583e6192296b0e4ce6825d7c01a 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 9fc223bdc7c0e207ce2005cb86250aa77e709df8..0e46a9f5e9d68fa2174f7bd9b9fa7c3a82dfb715 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -108,10 +108,13 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, computation.Snapshot()); // Serialize the HloSnapshot deterministically so that all the outputs of a // tf_library genrule are deterministic. 
- string proto; - TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); + const size_t size = module->ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK( + SerializeToBufferDeterministic(*module, serialized.get(), size)); TF_RETURN_IF_ERROR( - WriteStringToFile(Env::Default(), flags.out_session_module, proto)); + WriteStringToFile(Env::Default(), flags.out_session_module, + absl::string_view(serialized.get(), size))); } xla::cpu::CpuAotCompilationOptions aot_opts( flags.target_triple, flags.target_cpu, flags.target_features, diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 10fa33ab5e84dcbc1629bee6214e8969046f19c2..444264ba6e1f59c33551796025ba845c62c02d43 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -69,6 +69,7 @@ genrule( "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", "test_graph_tftop_k.pb", + "test_graph_tfvariable.pb", ], # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any # GPUs which might be present. 
This is important because builds may run @@ -222,6 +223,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfvariable", + testonly = 1, + config = "test_graph_tfvariable.config.pbtxt", + cpp_class = "VariableComp", + graph = "test_graph_tfvariable.pb", + tags = [ + "manual", + ], +) + tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -241,6 +253,7 @@ tf_cc_test( ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", ":test_graph_tftop_k", + ":test_graph_tfvariable", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 64b861a73091642b03573543a5c55618bf33915d..42f8812def0503824416d92daa2db71a64c3db88 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -50,7 +50,7 @@ def tfadd_with_ckpt(out_dir): y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') - init_op = variables.initialize_all_variables() + init_op = variables.global_variables_initializer() saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) with session.Session() as sess: sess.run(init_op) @@ -65,7 +65,7 @@ def tfadd_with_ckpt_saver(out_dir): y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') - init_op = variables.initialize_all_variables() + init_op = variables.global_variables_initializer() saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1) with session.Session() as sess: sess.run(init_op) @@ -149,6 +149,14 @@ def tftop_k(_): array_ops.identity(output[1], name='indices') +def tfvariable(_): + x = variables.Variable(1000.0, name='x') + old_x = x.value() + with ops.control_dependencies([old_x]): + new_x = x.assign_add(42.0) + array_ops.stack([old_x, new_x], name='result') + + def 
write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -171,6 +179,7 @@ def main(_): write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) write_graph(tftop_k, FLAGS.out_dir) + write_graph(tfvariable, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9b4c4215a330b014f595edde001aba73ad7d8263 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt @@ -0,0 +1,12 @@ +# Text form of tensorflow.tf2xla.Config proto. +fetch { + id { node_name: "result" } +} + +variable { + node_name: "x" + shape { + dim { size: 1 } + } + type: DT_FLOAT +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 4dd79e5882d7da61be029735ef2b165908c599f9..5f9316f3933713e12fc5960b9adfecc6e9bd99b5 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" #include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -473,6 +474,28 @@ TEST(TFCompileTest, TopK) { EXPECT_EQ(expected_indices[1], fn.result1(1)); } +TEST(TFCompileTest, Variable) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + VariableComp fn; + float x = 23; + fn.set_var_x_input_data(&x); + + fn.set_thread_pool(&device); + fn.Run(); + EXPECT_EQ(fn.result0(0, 0), 23); + EXPECT_EQ(fn.result0(1, 0), 65); + EXPECT_EQ(fn.var_x_result(), 65); + + EXPECT_EQ(x, 23); + x = fn.var_x_result(); + fn.Run(); + EXPECT_EQ(fn.result0(0, 0), 65); + EXPECT_EQ(fn.result0(1, 0), 107); + EXPECT_EQ(fn.var_x_result(), 107); +} + TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 2dc3e8c9113b37bf9d575ad66783f4ab49478af4..2abe3e29b78dbbe719637b13418704acc213d050 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -207,7 +207,7 @@ def tf_library( # # Note that setting the local=1 attribute on a *test target* causes the # test infrastructure to skip that test. However this is a genrule, not - # a test target, and runs with --genrule_strategy=forced_forge, meaning + # a test target, and runs with --strategy=Genrule=forced_forge, meaning # the local=1 attribute is ignored, and the genrule is still run. 
# # https://www.bazel.io/versions/master/docs/be/general.html#genrule @@ -283,7 +283,7 @@ def tf_library( ) # Variables used for gen_test and gen_benchmark. - cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + cpp_class_split = cpp_class.rsplit("::", 2) if len(cpp_class_split) == 1: no_ns_name = cpp_class_split[0] else: diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index d548de8c44285f6d21dd778db464a31e1b19645b..0b6ab7e723d6e3a55da2f1c30b75f44cbdaa75bb 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -136,6 +136,10 @@ int main(int argc, char** argv) { tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); + if (argc > 1 && absl::string_view(argv[1]) == "--help") { + std::cerr << usage << "\n"; + return 0; + } bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); QCHECK(parsed_flags_ok) << "\n" << usage; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 15dcbb2641eca031e82db9aa58dee6a14ab0a2cc..121de401cefb2b56b984944dde769f226590dc67 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -175,12 +175,22 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:stream_pool", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:resource_variable_ops_op_lib", + 
"//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", @@ -198,9 +208,11 @@ cc_library( "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", ], ) @@ -271,7 +283,6 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -454,7 +465,6 @@ cc_library( "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -515,6 +525,7 @@ cc_library( "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", @@ -613,6 +624,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:scope", @@ -625,15 +637,16 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", 
"//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:test", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "//tensorflow/core/grappler/optimizers/data:graph_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 9f4042630edaec1b9519b6434d859a48372e8b15..285b1efa53d91922c9fa161cfd2de34e1434d0c4 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -115,6 +115,13 @@ void MergeOutgoingControlEdges(const Scope& s, Node* old_node, Node* new_node) { return; } + if (ctrl_edges.size() == 1 && ctrl_edges.front()->dst()->IsSink()) { + // Avoid creating a Merge node if we can just add an edge to _SINK + // instead. + s.graph()->AddControlEdge(new_node, s.graph()->sink_node()); + return; + } + // We can't merge control edges directly so we instead first "convert" them to // normal values that can be merged, merge the values and then "convert" the // merged value back into control. diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 48a23a4c1711ac88a329723c46559112d5a39dbd..c14c7465c55b7d350d6b3a6853cef6692140ce78 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -69,6 +68,8 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { } } + FixupSourceAndSinkEdges(graph.get()); + GraphOptimizationPassOptions opt_options; opt_options.graph = &graph; BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true); @@ -224,5 +225,23 @@ TEST_F(BuildXlaOpsTest, OnXlaDevice) { ASSERT_NE(write_op_new, nullptr); EXPECT_THAT(write_op_new, assign_var); } + +TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* sink_node = graph->sink_node(); + EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")), + NodeWith(Op("cluster_0")), + NodeWith(Op("NoOp"))))); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0562838f628c66b1eb03af9d2a5139c01dca31c5..4397eea9af266cbd0392f08323e59077c9395150 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -20,7 +20,10 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/hash/hash.h" @@ -110,7 +113,11 @@ class Predicate { enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol }; virtual string ToString() const = 0; - int64 hash() const { return hash_; } + + // An ID assigned to the Predicate at construction time. Conceptually like a + // pointer, except that it is stable across runs. + int64 id() const { return id_; } + virtual absl::Span GetOperands() const = 0; virtual Kind kind() const = 0; @@ -123,29 +130,19 @@ class Predicate { static void Visit(Predicate* p, const FunctionTy& func); protected: - explicit Predicate(int64 hash) : hash_(hash) {} + explicit Predicate(int64 id) : id_(id) {} private: - const int64 hash_; + const int64 id_; TF_DISALLOW_COPY_AND_ASSIGN(Predicate); }; -int64 HashPredicateSequence(Predicate::Kind kind, - absl::Span preds) { - int64 hash = ::tensorflow::hash()(kind); - for (Predicate* pred : preds) { - hash = Hash64Combine(hash, pred->hash()); - } - return hash; -} - // Represents a logical conjunction of a set of predicates. class AndPredicate : public Predicate { public: - explicit AndPredicate(std::vector operands) - : Predicate(HashPredicateSequence(Kind::kAnd, operands)), - operands_(std::move(operands)) {} + explicit AndPredicate(int64 id, std::vector operands) + : Predicate(id), operands_(std::move(operands)) {} string ToString() const override { if (operands().empty()) { @@ -174,9 +171,8 @@ class AndPredicate : public Predicate { // Represents a logical disjunction of a set of predicates. 
class OrPredicate : public Predicate { public: - explicit OrPredicate(std::vector operands) - : Predicate(HashPredicateSequence(Kind::kOr, operands)), - operands_(std::move(operands)) {} + explicit OrPredicate(int64 id, std::vector operands) + : Predicate(id), operands_(std::move(operands)) {} string ToString() const override { if (operands().empty()) { @@ -204,9 +200,8 @@ class OrPredicate : public Predicate { // Represents a logical negation of a set of predicates. class NotPredicate : public Predicate { public: - explicit NotPredicate(Predicate* operand) - : Predicate(HashPredicateSequence(Kind::kNot, {operand})), - operands_({operand}) {} + explicit NotPredicate(int64 id, Predicate* operand) + : Predicate(id), operands_({operand}) {} string ToString() const override { return absl::StrCat("~", operand()->ToString()); @@ -222,29 +217,38 @@ class NotPredicate : public Predicate { std::array operands_; }; -// Represents an infinite list of predicates. +// Represents the liveness of an induction variable. For users inside the loop +// this represents the "current" liveness of the induction variable. For users +// outside the loop it represents the "last" liveness of the induction variable. +// +// More concretely, an and recurrence {S,&,X} represents the liveness of V +// in the following graph: // -// An AndRecurrence with start = S and step = X is printed as {S,&,X} and stands -// for the list of predicates: +// V = Merge(S', V_NextIt) +// V = Op(V, X') +// V_NextIt = NextIteration(V) // -// S, S & GenSym(X,1), S & GenSym(X,1) & GenSym(X,2), ... +// where Predicate(S') = S and Predicate(X') = X. // -// where GenSym(, ) renames every SymbolPredicate in -// by appending to it, in effect creating a "fresh" symbol. -// This means {P,&,Q} is not equal to "P on the first iteration; P&Q on -// subsequent iterations". +// `X` may contain symbolic predicates and the operations corresponding to these +// symbolic predicates are either in frame `loop` or outside it. 
The symbols +// that are inside frame `loop` are loop variant (i.e. can have different +// liveness in each loop iteration) and the symbols that are outside frame +// `loop` are loop invariant (i.e. have the same liveness across all +// iterations). class AndRecurrencePredicate : public Predicate { public: - explicit AndRecurrencePredicate(Predicate* start, Predicate* step) - : Predicate(HashPredicateSequence(Kind::kAndRecurrence, {start, step})), - operands_({start, step}) {} + explicit AndRecurrencePredicate(int64 id, Predicate* start, Predicate* step, + std::vector frame) + : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {} Predicate* start() const { return operands_[0]; } Predicate* step() const { return operands_[1]; } + absl::Span frame() const { return frame_; } string ToString() const override { return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), - "}"); + "}<", absl::StrJoin(frame(), ";"), ">"); } Kind kind() const override { return Kind::kAndRecurrence; } @@ -255,6 +259,7 @@ class AndRecurrencePredicate : public Predicate { private: std::array operands_; + std::vector frame_; }; // Represents an uninterpreted symbol in a logical predicate. @@ -264,8 +269,8 @@ class AndRecurrencePredicate : public Predicate { // symbols. class SymbolPredicate : public Predicate { public: - explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) - : Predicate(Hash(tensor_id, must_be_true)), + explicit SymbolPredicate(int64 id, TensorId tensor_id, bool must_be_true) + : Predicate(id), tensor_id_(std::move(tensor_id)), must_be_true_(must_be_true) {} @@ -281,20 +286,13 @@ class SymbolPredicate : public Predicate { // "tensor_id() is live and evaluates to true". 
// // If `must_be_true()` is false then this SymbolPredicate represents the - // proposition "tensor_id() is live (and may evalutate to any value)" + // proposition "tensor_id() is live (and may evaluate to any value)" TensorId tensor_id() const { return tensor_id_; } bool must_be_true() const { return must_be_true_; } private: TensorId tensor_id_; bool must_be_true_; - - static int64 Hash(const TensorId tensor_id, bool must_be_true) { - return Hash64Combine( - ::tensorflow::hash()(must_be_true), - Hash64Combine(::tensorflow::hash()(Kind::kSymbol), - TensorId::Hasher{}(tensor_id))); - } }; template @@ -333,34 +331,58 @@ class PredicateFactory { } Predicate* MakeNotPredicate(Predicate* pred) { - SignatureForNot signature = pred; - auto it = interned_not_instances_.find(signature); - if (it == interned_not_instances_.end()) { - std::unique_ptr new_pred = Make(pred); - Predicate* new_pred_ptr = new_pred.get(); - interned_not_instances_.emplace(signature, std::move(new_pred)); - return new_pred_ptr; - } else { - return it->second.get(); + auto it = make_not_predicate_cache_.find(pred); + if (it != make_not_predicate_cache_.end()) { + return it->second; } + + Predicate* result = MakeNotPredicateImpl(pred); + + bool insert_successful = + make_not_predicate_cache_.insert({pred, result}).second; + (void)insert_successful; + DCHECK(insert_successful); + + return result; } - Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) { - auto it = interned_and_rec_instances_.find({start, step}); + Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step, + std::vector frame) { + SignatureForAndRec signature(start, step, std::move(frame)); + auto it = interned_and_rec_instances_.find(signature); if (it != interned_and_rec_instances_.end()) { return it->second.get(); } - std::unique_ptr new_pred = - Make(start, step); + std::unique_ptr new_pred = Make( + std::get<0>(signature), std::get<1>(signature), std::get<2>(signature)); Predicate* new_pred_ptr 
= new_pred.get(); - CHECK(interned_and_rec_instances_ - .emplace(SignatureForAndRec(start, step), std::move(new_pred)) - .second); + bool inserted = + interned_and_rec_instances_.emplace(signature, std::move(new_pred)) + .second; + (void)inserted; + DCHECK(inserted); return new_pred_ptr; } - Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { + Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true, + Predicate** predicate) { + TensorId tensor_id(node->name(), output_idx); + + bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL; + TF_RET_CHECK(!must_be_true || is_boolean_tensor); + + if (node->type_string() == "Const" && must_be_true) { + const TensorProto* proto = nullptr; + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto)); + + Tensor tensor(proto->dtype()); + TF_RET_CHECK(tensor.FromProto(*proto)); + + *predicate = tensor.scalar()() ? MakeTrue() : MakeFalse(); + return Status::OK(); + } + SignatureForSymbol signature = {tensor_id, must_be_true}; auto it = interned_symbol_instances_.find(signature); if (it == interned_symbol_instances_.end()) { @@ -369,20 +391,70 @@ class PredicateFactory { Predicate* new_pred_ptr = new_pred.get(); interned_symbol_instances_.emplace(std::move(signature), std::move(new_pred)); - return new_pred_ptr; + *predicate = new_pred_ptr; } else { - return it->second.get(); + *predicate = it->second.get(); } + + return Status::OK(); } Predicate* MakeTrue() { return MakeAndPredicate({}); } Predicate* MakeFalse() { return MakeOrPredicate({}); } + ~PredicateFactory() { + DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?"; + } + private: + Predicate* MakeNotPredicateImpl(Predicate* pred) { + IncrementStackDepth stack_frame(this); + if (!stack_frame.HasOverflowed()) { + if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) { + return simplified; + } + + // ~~A => A + if (auto* not_pred = dynamic_cast(pred)) { + return not_pred->operand(); + } + } + + 
SignatureForNot signature = pred; + auto it = interned_not_instances_.find(signature); + if (it == interned_not_instances_.end()) { + std::unique_ptr new_pred = Make(pred); + Predicate* new_pred_ptr = new_pred.get(); + interned_not_instances_.emplace(signature, std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } + } + + Predicate* SimplifyUsingDeMorgan(Predicate* pred) { + // ~(A & B & C & ...) => ~A | ~B | ~C | ~... + // ~(A | B | C | ...) -> ~A & ~B & ~C & ~... + Predicate::Kind kind = pred->kind(); + + if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) { + std::vector new_operands; + absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands), + [&](Predicate* p) { return MakeNotPredicate(p); }); + return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands) + : MakeOrPredicate(new_operands); + } + + return nullptr; + } + template std::unique_ptr Make(Args&&... args) { + // If we ever expose the Predicate class outside this .cc file then we may + // want to make this hard to misuse (by accidentally passing in an arbitrary + // integer to the Predicate constructor for instance). return std::unique_ptr( - new PredicateT(std::forward(args)...)); + new PredicateT(id_counter_++, std::forward(args)...)); } Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); @@ -402,7 +474,8 @@ class PredicateFactory { using SignatureForAndOr = std::pair>; using SignatureForNot = Predicate*; - using SignatureForAndRec = std::pair; + using SignatureForAndRec = + std::tuple>; using SignatureForSymbol = std::pair; struct HashSignatureForAndOr { @@ -422,6 +495,36 @@ class PredicateFactory { } }; + // Used to limit recursion to avoid blowing up the stack and cap compile time. 
+ class IncrementStackDepth { + public: + explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) { + parent_->stack_depth_++; + } + + bool HasOverflowed() const { + const int kMaxStackDepth = 8; + return parent_->stack_depth_ >= kMaxStackDepth; + } + + ~IncrementStackDepth() { parent_->stack_depth_--; } + + private: + PredicateFactory* parent_; + }; + + // A cache for the MakeNotPredicate function. + // + // NB! This is *not* the same as `interned_not_instances_`. + // `interned_not_instances_` ensures pointer identity for `NotPredicate` + // instances, i.e., it ensures there is at most one instance of Not(predicate) + // for any given predicate whereas `make_not_predicate_cache_` simply caches + // the result of the `MakeNotPredicate` function. The values in + // `interned_not_instances_` are always instances of `NotPredicate` whereas the + // values in `make_not_predicate_cache_` may not be (for instance it will map + // Not(Not(A)) to A). + absl::flat_hash_map make_not_predicate_cache_; + absl::flat_hash_map, HashSignatureForAndOr> interned_and_or_instances_; @@ -432,13 +535,15 @@ class PredicateFactory { absl::flat_hash_map, HashSignatureForSymbol> interned_symbol_instances_; + int64 id_counter_ = 0; + int stack_depth_ = 0; }; Predicate* PredicateFactory::MakeInternedAndOr( std::vector simplified_ops, Predicate::Kind pred_kind) { std::stable_sort( simplified_ops.begin(), simplified_ops.end(), - [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + [](Predicate* a, Predicate* b) { return a->id() < b->id(); }); auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); if (it != interned_and_or_instances_.end()) { @@ -466,6 +571,13 @@ Predicate* PredicateFactory::MakeAndOrImpl( absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? 
Predicate::Kind::kAnd : Predicate::Kind::kOr; + + IncrementStackDepth stack_frame(this); + if (stack_frame.HasOverflowed()) { + return MakeInternedAndOr( + std::vector(operands.begin(), operands.end()), pred_kind); + } + Predicate::Kind other_pred_kind = is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; absl::flat_hash_set simplified_ops_set; @@ -494,16 +606,31 @@ Predicate* PredicateFactory::MakeAndOrImpl( // Simplify "A&~A=>False" and "A|~A=>True". absl::flat_hash_set negated_ops; - for (Predicate* op : simplified_ops) { - if (op->kind() == Predicate::Kind::kNot) { - negated_ops.insert(dynamic_cast(*op).operand()); - } - } - for (Predicate* op : simplified_ops) { if (negated_ops.count(op)) { + // Simple case: + // + // A & ~A & ... == False + // A | ~A | ... == True return is_and ? MakeFalse() : MakeTrue(); } + + Predicate* negated_op = MakeNotPredicate(op); + if (negated_op->kind() == pred_kind) { + // Slightly more complicated case: + // + // (~A | ~B | ~C) & A & B & C & ... == + // ~(A & B & C) & (A & B & C) & ... == False + // + // (~A & ~B & ~C) | A | B | C | ... == + // ~(A | B | C) | (A | B | C) | ... == True + if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) { + return simplified_ops_set.contains(p); + })) { + return is_and ? 
MakeFalse() : MakeTrue(); + } + } + negated_ops.insert(negated_op); } // If all ops contain the same subop, then factor it out thanks to the @@ -619,6 +746,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { const Graph& graph_; absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; + std::vector control_flow_info_; bool vlog_; }; @@ -661,9 +789,12 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); - Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( - TensorId(pred_edge->src()->name(), pred_edge->src_output()), - /*must_be_true=*/true); + + Predicate* true_switch; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + pred_edge->src(), pred_edge->src_output(), + /*must_be_true=*/true, &true_switch)); + Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); // Output 0 is alive iff all inputs are alive and the condition is false. @@ -761,6 +892,23 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr; } + +Status GetFullFrame(const Node* n, absl::Span cfi_infos, + std::vector* frame) { + int depth = 0; + for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource(); + n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) { + frame->push_back(cfi_iter->frame_name); + + if (depth++ > 5000) { + return errors::Internal( + "Frame of depth > 5000: Probably malformed graph or a bug in " + "BuildControlFlowInfo"); + } + } + + return Status::OK(); +} } // namespace Status DeadnessAnalysisImpl::HandleMerge(Node* n, @@ -783,8 +931,10 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, if (has_unvisited_backedge) { // We're visiting this merge for the first time and it has an unvisited // backedge. 
- Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false); + Predicate* input_data_pred; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -825,8 +975,10 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, Predicate* start = predicate_factory_.MakeOrPredicate(non_recurrent_inputs); - Predicate* and_rec = - predicate_factory_.MakeAndRecurrencePredicate(start, step); + std::vector frame; + TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame)); + Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate( + start, step, std::move(frame)); SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); return Status::OK(); } @@ -841,8 +993,10 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, // acquire a dead signal from a _Send. std::vector input_preds; TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); - input_preds.push_back(predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false)); + Predicate* signal_is_alive; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive)); + input_preds.push_back(signal_is_alive); SetPredicate(n, {0, Graph::kControlSlot}, predicate_factory_.MakeAndPredicate(input_preds), should_revisit); @@ -892,6 +1046,24 @@ Status DeadnessAnalysisImpl::Populate() { Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( absl::Span rpo) { + std::vector unreachable_nodes; + // Compute the loop structure of the graph. 
+ TF_RETURN_IF_ERROR( + BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes)); + + // Do some opportunistic error checking: + if (!unreachable_nodes.empty()) { + if (unreachable_nodes.size() > 5) { + unreachable_nodes.erase(unreachable_nodes.begin() + 5, + unreachable_nodes.end()); + } + + return errors::InvalidArgument( + "Found unreachable nodes, most likely source and sink nodes not " + "connected: ", + absl::StrJoin(unreachable_nodes, ", ")); + } + // This an abstract interpretation over the deadness propagation semantics of // the graph executor. // diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 8a73101c184e6190921fd7729742922bd96f4bcf..38a5118d9a721b814e1b52ce4202d4fb783e3ac3 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -123,10 +123,9 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); Output loop_cond_expr = - ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value); - Output loop_cond = - ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); - ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, + loop_cond_expr); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), @@ -140,7 +139,7 @@ InductionVarInfo CreateInductionVariable(const Scope& root, root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); root.graph()->AddControlEdge(iv.output.node(), final_value.node()); - return {iv.output, loop_cond}; + return {iv.output, loop_cond_expr}; } 
InductionVarInfo CreateInductionVariable(const Scope& root, @@ -515,24 +514,27 @@ TEST(DeadnessAnalysisTest, Loop) { // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0 // produce the same deadness. But we're not that smart today. - EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], "{#true,&,*iv0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], "{#true,&,*iv1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], + "{#true,&,*iv1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], + "{#true,&,*iv2/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); + "({#true,&,*iv0/cond:0} & {#true,&,*iv1/cond:0})"); EXPECT_EQ(predicate_map[ControlOutputFor(add1)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); + "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); } } TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); - InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0); Output dependent_iv0 = - CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0) .induction_var; Output dependent_iv1 = - CreateDependentLoopInvariantValue(root, "div1", "frame", iv.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0) .induction_var; Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1); @@ -549,13 +551,13 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], - "{#true,&,*iv0/cond:0}"); + 
"{#true,&,*iv0/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); } } @@ -595,32 +597,33 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); InductionVarInfo iv_outer = - CreateInductionVariable(root, "iv_outer", "frame", 0); + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); ops::Switch inner_value(root.WithOpName("outer_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer.loop_cond); + enter_constant_outer_loop, iv_outer.loop_cond); InductionVarInfo iv_inner = CreateInductionVariable( - root, "iv_inner", "frame", - ops::internal::Enter(root.WithOpName("iv_inner/enter"), - inner_value.output_true, "frame_inner")); + root, "iv_inner", "inner_loop", inner_value.output_true); Output dependent_outer_iv0 = - CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", "frame", - iv_outer.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", + "outer_loop", iv_outer.loop_cond, 0) .induction_var; Output dependent_outer_iv1 = - CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", "frame", - iv_outer.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", + "outer_loop", iv_outer.loop_cond, 0) .induction_var; - Output dependent_inner_iv0 = - CreateDependentLoopInvariantValue(root, 
"dependent_inner_iv0", "frame", - iv_inner.loop_cond, dependent_outer_iv0) - .induction_var; - Output dependent_inner_iv1 = - CreateDependentLoopInvariantValue(root, "dependent_inner_iv1", "frame", - iv_inner.loop_cond, dependent_outer_iv1) - .induction_var; + Output dependent_inner_iv0 = CreateDependentLoopInvariantValue( + root, "dependent_inner_iv0", "inner_loop", + iv_inner.loop_cond, dependent_outer_iv0) + .induction_var; + Output dependent_inner_iv1 = CreateDependentLoopInvariantValue( + root, "dependent_inner_iv1", "inner_loop", + iv_inner.loop_cond, dependent_outer_iv1) + .induction_var; Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0, dependent_inner_iv1); @@ -638,46 +641,51 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], - "{#true,&,*iv_outer/cond:0}"); + "{#true,&,*iv_outer/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], - "{(*iv_outer/cond:0 & {#true,&,*iv_outer/cond:0}),&," - "*iv_inner/cond:0}"); + "{(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); } } 
TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); - InductionVarInfo iv_outer_0 = - CreateInductionVariable(root, "iv_outer_0", "frame", 0); - ops::Switch inner_value_0(root.WithOpName("outer_0_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer_0.loop_cond); - InductionVarInfo iv_inner_0 = CreateInductionVariable( - root, "iv_inner_0", "frame", - ops::internal::Enter(root.WithOpName("iv_inner_0/enter"), - inner_value_0.output_true, "frame_inner")); - - InductionVarInfo iv_outer_1 = - CreateInductionVariable(root, "iv_outer_1", "frame", 1); - ops::Switch inner_init_value_1(root.WithOpName("outer_1_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer_1.loop_cond); - InductionVarInfo iv_inner_1 = CreateInductionVariable( - root, "iv_inner_1", "frame", - ops::internal::Enter(root.WithOpName("iv_inner_1/enter"), - inner_init_value_1.output_true, "frame_inner")); - Output add0 = ops::Add(root.WithOpName("add0"), iv_inner_0.induction_var, - iv_inner_1.induction_var); + + std::array outer_iv; + std::array inner_iv; + + for (int i : {0, 1}) { + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + enter_constant_outer_loop, iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "inner_loop", inner_value.output_true); + + outer_iv[i] = iv_outer.induction_var; + inner_iv[i] = iv_inner.induction_var; + } + + Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]); VLogGraphIfAsked(*root.graph()); @@ -692,21 +700,77 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { PredicateMapTy 
predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_0.induction_var)], - "{#true,&,*iv_outer_0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_0.induction_var)], - "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," - "*iv_inner_0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_1.induction_var)], - "{#true,&,*iv_outer_1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_1.induction_var)], - "{(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," - "*iv_inner_1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])], + "{(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])], + "{#true,&,*iv_outer/cond_1:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])], + "{(*iv_outer/cond_1:0 & " + "{#true,&,*iv_outer/cond_1:0}),&,*iv_inner/" + "cond_1:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," - "*iv_inner_1/cond:0} & " - "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," - "*iv_inner_0/cond:0})"); + "({(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0} & {(*iv_outer/cond_1:0 & " + "{#true,&,*iv_outer/cond_1:0}),&,*iv_inner/" + "cond_1:0})"); + } +} + +TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10); + InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9); + + Output init = CreateSwitch(root, "init").output_true; + Output step = CreateSwitch(root, "step").output_true; + + std::array exits; + std::array next_iterations; + + for (int i : {0, 1}) { + Output init_enter = ops::internal::Enter( + 
root.WithOpName(absl::StrCat("init_enter_frame_", i)), init, + absl::StrCat("frame_", i), + ops::internal::Enter::Attrs().IsConstant(true)); + Output step_enter = ops::internal::Enter( + root.WithOpName(absl::StrCat("step_enter_frame_", i)), step, + absl::StrCat("frame_", i), + ops::internal::Enter::Attrs().IsConstant(true)); + + ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)), + {init_enter, init_enter}); + Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output, + step_enter); + next_iterations[i] = ops::NextIteration( + root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add); + EXPECT_TRUE( + root.graph() + ->UpdateEdge(next_iterations[i].node(), 0, iv.output.node(), 1) + .ok()); + exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)), + iv.output); + } + + FixupSourceAndSinkEdges(root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], + predicate_map[ControlOutputFor(exits[1])]); + EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], ""); + EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], ""); + + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], + predicate_map[ControlOutputFor(next_iterations[1])]); + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], ""); + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], ""); } } @@ -818,5 +882,82 @@ TEST(DeadnessAnalysisTest, RecvVsSwitchText) { EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)"); } +TEST(DeadnessAnalysisTest, DeMorgan) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL); + Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + + ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0); + ops::Switch 
sw_1(root.WithOpName("switch_1"), value, cond_1); + + Output and_0_1 = + ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true); + + Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"), + {sw_0.output_false, sw_1.output_false}) + .output; + + // Predicate(should_always_be_dead) = + // (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False + Output should_always_be_dead = + ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1); + + // Predicate(should_always_be_alive) = + // (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True + Output should_always_be_alive = + ops::Merge(root.WithOpName("should_always_be_alive"), + {and_0_1, or_not0_not1}) + .output; + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false"); + EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true"); +} + +TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output constant_true = ops::Const(root.WithOpName("const_true"), true); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, constant_true); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#false"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#true"); +} + +TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output constant_false = ops::Const(root.WithOpName("const_false"), 
false); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, constant_false); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#true"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f478832781cb1dc045d9163d4a6f5e5f64a8a705..d0d7a3f3785469acd79a83b6897668f94fc6ea2e 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -779,7 +779,8 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp, + NodeDebugInfo(src_node->def())); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -814,7 +815,8 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp, + NodeDebugInfo(src_node->def())); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -974,6 +976,7 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; + // TODO(shikharagarwal): What source node should we use for errors? 
NodeDefBuilder builder(absl::StrCat("outside_compilation_", oc_subgraph_name, "_host_compute"), kHostComputeOp); @@ -1005,13 +1008,15 @@ Status Encapsulator::Subgraph::AddHostComputes( // subgraph. for (const auto& src_node : oc_subgraph.control_inputs) { Node* src_image = node_images.at(src_node); - graph_->AddControlEdge(src_image, host_compute); + graph_->AddControlEdge(src_image, host_compute, + /* allow_duplicates= */ true); } // Connect the _HostCompute node to its ancestor host compute nodes. for (const auto& ancestor_name : host_compute_ancestors) { Node* ancestor = host_compute_node[ancestor_name]; - graph_->AddControlEdge(ancestor, host_compute); + graph_->AddControlEdge(ancestor, host_compute, + /* allow_duplicates= */ true); } // Connect the consumers in the subgraph to the _HostCompute node. @@ -1028,7 +1033,8 @@ Status Encapsulator::Subgraph::AddHostComputes( // node. for (const auto& dst_node : oc_subgraph.control_outputs) { Node* dst_image = node_images.at(dst_node); - graph_->AddControlEdge(host_compute, dst_image); + graph_->AddControlEdge(host_compute, dst_image, + /* allow_duplicates= */ true); } } } @@ -1040,6 +1046,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; + // TODO(shikharagarwal): What source node should we use for errors? 
NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); @@ -1055,7 +1062,8 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { VLOG(2) << "ConnectSequencerToCallNode"; - graph_out->AddControlEdge(sequencer_, call_node_); + graph_out->AddControlEdge(sequencer_, call_node_, + /* allow_duplicates= */ true); } } @@ -1214,7 +1222,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder", + NodeDebugInfo(call_node_def_)); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1248,6 +1257,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); @@ -1273,7 +1283,8 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( // completes. This has no effect on execution order but prevents the // RecvAtHost being pruned. TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_); + graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_, + true /* skip duplicates check */); return Status::OK(); } @@ -1303,6 +1314,7 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; + // TODO(shikharagarwal): What source node should we use for errors? 
NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); @@ -1329,7 +1341,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( // subgraph completes. This has no effect on execution order but prevents the // RecvAtHost being pruned. TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_); + graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_, + /* allow_duplicates= */ true); return Status::OK(); } @@ -1439,7 +1452,8 @@ Status Encapsulator::CopySubgraphEdges( src_func_id == dst_func_id) { Graph* g = subgraphs_[src_func_id].GetGraph(); if (edge->IsControlEdge()) { - g->AddControlEdge(src_image, dst_image); + g->AddControlEdge(src_image, dst_image, + /* allow_duplicates= */ true); } else { g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); } @@ -1725,7 +1739,8 @@ Status Encapsulator::CopyEdgeToOutputGraph( if (edges_added ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1)) .second) { - graph_out->AddControlEdge(src_image, dst_image); + graph_out->AddControlEdge(src_image, dst_image, + /* allow_duplicates= */ true); } return Status::OK(); @@ -1754,7 +1769,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { const string& subgraph = ancestors.first; for (const string& ancestor : ancestors.second) { graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), - subgraphs_[subgraph].GetCallNode()); + subgraphs_[subgraph].GetCallNode(), + /* allow_duplicates= */ true); } } return Status::OK(); @@ -1833,8 +1849,9 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port, // Add any Enter nodes required to bring the constant to the correct control // flow frame. 
while (!control_flow_info[src_node->id()].frame_name.empty()) { + NodeDebugInfo debug_info(*src_node); NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter", - options.op_registry()); + options.op_registry(), &debug_info); enter_builder.Attr("frame_name", control_flow_info[src_node->id()].frame_name); enter_builder.Attr("is_constant", true); @@ -2018,7 +2035,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::InvalidArgument( "Shape inference is not possible for outside_compilation " "SendFromHost node ", - send_node->name(), " because shape of node ", n->name(), + send_node->name(), " because shape of node ", + FormatNodeForError(*n), " will not be known at compilation time."); } } @@ -2047,8 +2065,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::Internal( "Internal assumption failed while rewriting an outside_compilation " "cluster that contains a while loop. Logic assumes back-edge is to " - "port 1 of a 2-input " - "Merge node."); + "port 1 of a 2-input Merge node."); } // Connect the existing edge to both inputs of the Merge node so that the // graph will be well-formed. @@ -2121,7 +2138,8 @@ Status CheckClusterDependencyForCycles( const string& ancestor, const string& successor, const std::unordered_map>& ancestors, const std::unordered_map& node_ancestors_map, - GraphCycles* cycle_detector, std::map* cycle_detector_map) { + GraphCycles* cycle_detector, + std::unordered_map* cycle_detector_map) { if (cycle_detector_map->find(ancestor) == cycle_detector_map->end()) { (*cycle_detector_map)[ancestor] = cycle_detector->NewNode(); } @@ -2165,7 +2183,7 @@ Status Encapsulator::FindClusterDependencies() { // We check that clusters are acyclic using this cycle detector. GraphCycles cycle_detector; // Map from cluster name to cycle detector node id. - std::map cycle_detector_map; + std::unordered_map cycle_detector_map; // Process the nodes in topologically-sorted order. 
std::vector nodes; GetReversePostOrder(*graph_in_, &nodes); @@ -2527,7 +2545,33 @@ Status EncapsulateSubgraphsPass::Run( std::vector* input_permutation, std::vector* output_permutation, NodeDef* node) { // Optimize the subgraph. - OptimizeGraph(flr, subgraph); + // Do not constant fold nodes that output DT_VARIANT type tensors. + // XLA does not support Const nodes of Variant type since it needs + // to know the original ops to be able to compile them to the relevant + // XLA form. + // TODO(srbs): This filter is a little conservative. E.g. a subgraph of + // the form: + // Const + // | + // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op + // | + // (Discard popped list) + // + // Would have been reduced to "Const -> Op" without this filter. + // However since we are only allowed to specify the filter at the "Node" + // level there is no good way to allow the above behavior. So we + // disallow any sort of constant folding on Variant nodes for now. + auto cf_consider_fn = [](const Node* n) { + for (const auto& output_arg : n->op_def().output_arg()) { + if (output_arg.type() == DT_VARIANT) { + return false; + } + } + return true; + }; + GraphOptimizer::Options graph_optimizer_options; + graph_optimizer_options.cf_consider_fn = cf_consider_fn; + OptimizeGraph(flr, subgraph, graph_optimizer_options); const int num_args = input_permutation->size(); std::vector const_args(num_args); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index de89be9a3555960dabe7bacd17226c15ae888ae6..261519de3478c8b3e30d206a15944b5a686598e2 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -25,6 +25,8 @@ limitations under the License. 
#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -32,6 +34,8 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -299,26 +303,10 @@ REGISTER_OP("XlaHostCompute") .Attr("Toutputs: list(type) >= 0") .Attr("ancestors: list(string) >= 0") .Attr("key: string") - .Attr("shape_inference_graph: string = ''") + .Attr("shape_inference_graph: func") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); -REGISTER_OP("_XlaSendFromHost") - .Input("inputs: Tinputs") - .Input("dynamic_key: string") - .Attr("Tinputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - -REGISTER_OP("_XlaRecvAtHost") - .Input("dynamic_key: string") - .Output("outputs: Toutputs") - .Attr("Toutputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - REGISTER_OP("InputTest") .Output("o: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { @@ -510,12 +498,20 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, s = ConvertGraphDefToGraph(options, *graphdef, graph.get()); if (!s.ok()) return s; - s = PerformStaticShapeInferenceBeforeEncapsulation( - graph.get(), 
"_encapsulate", "_outside"); + s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get()); if (!s.ok()) return s; - s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside"); - if (!s.ok()) return s; + // Create FunctionLibraryRuntime. + SessionOptions session_options; + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + OptimizerOptions opts; + auto device_mgr = absl::make_unique(std::move(devices)); + auto pflr = absl::make_unique( + device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions( @@ -542,7 +538,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, std::map{}}); } s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters, - graph_out.get(), lib_def.get()); + graph_out.get(), flr, lib_def.get()); if (!s.ok()) return s; GraphDef graphdef_out; @@ -550,6 +546,14 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, graphdef->Swap(&graphdef_out); *library = lib_def->ToProto(); + // Remove "_xla_inferred_shapes" attr. They are added by + // `PerformStaticShapeInferenceBeforeEncapsulation`. 
+ for (FunctionDef& fdef : *library->mutable_function()) { + for (NodeDef& node_def : *fdef.mutable_node_def()) { + node_def.mutable_attr()->erase("_xla_inferred_shapes"); + } + } + return s; } @@ -901,18 +905,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -931,10 +939,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"c"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -948,16 +957,18 @@ 
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -966,9 +977,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); Node* call = - b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1022,14 +1033,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape1.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", 
"O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } @@ -1037,33 +1050,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), shape2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Binary(e, ops::NodeOut(recv2, 0), + shape2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1, shape_inference_graph2; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); 
*library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, {{"I"}, "UnaryTest", - {"outside_compilation_O2_host_compute:outputs:0"}}, + {"outside_compilation_O2_host_compute:outputs:1"}}, {{"F"}, "BinaryTest", {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, @@ -1073,13 +1098,14 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { "XlaHostCompute", {"F:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"F"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", @@ -1088,13 +1114,15 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, - {{"i_0_retval_retval", "I:o:0"}}); + {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"}, + {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1105,19 +1133,22 @@ TEST(EncapsulateSubgraphsTest, 
OneFunctionTwoOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") @@ -1130,7 +1161,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); Node* send2 = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s = Sequencer(b2.opts() .WithName("F1_sequencer") @@ -1139,12 +1171,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); - Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); + Node* call = + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(g, call, b2.opts().WithName("J")); + Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1), + b2.opts().WithName("J")); 
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } - TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } @@ -1196,7 +1229,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, - {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {}, + {"e_0_retval_retval:float", "f_0_retval_retval:float", + "d_0_retval_retval:float"}, + {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1212,35 +1247,41 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, - {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"d_0_retval_retval", "D:o:0"}, + {"f_0_retval_retval", "F:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"}, - {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {}, + "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { - {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}}, + {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"G:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"d_0_arg", "G:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, 
{"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}}); + {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1251,16 +1292,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* key_constant1 = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant1, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), "F1"); @@ -1268,29 +1311,33 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", - {DT_FLOAT}, b2.opts()); - Node* h = Binary(ops::NodeOut(call1, 
1), recv2, + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant2, 0), "F2", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* h = Binary(recv2, ops::NodeOut(recv2, 1), b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") .WithAttr("_outside", "O1")); - Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, - b2.opts()); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s2 = Sequencer( b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); - node_builder2.Input(call1).Input(e); + node_builder2.Input(call1) + .Input(ops::NodeOut(call1, 1)) + .Input(ops::NodeOut(call1, 2)); Node* call2 = b2.opts() - .WithControlInputs({s2, e, call1}) + .WithControlInputs({s2, call1}) .FinalizeBuilder(&node_builder2); - Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J")); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1326,8 +1373,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(g, b1.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); Binary(f, i, b1.opts().WithName("J")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -1358,10 +1404,12 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + 
{"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1380,10 +1428,12 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"i_0_retval_retval", "I:o:0"}}); @@ -1401,7 +1451,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, @@ -1413,7 +1463,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); @@ -1422,8 +1472,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(recv2, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, b2.opts()); @@ -1484,15 +1533,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, 
{{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1503,16 +1554,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = @@ -1569,15 +1623,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + 
{"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1591,13 +1647,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts()); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithControlInput(recv1) - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithControlInput(recv1) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( @@ -1644,8 +1700,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = 
FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1654,14 +1729,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1678,14 +1756,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1722,8 +1803,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder 
shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1736,14 +1836,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1760,7 +1863,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* send1 = 
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts().WithControlInput(e)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), @@ -1770,7 +1873,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1813,22 +1916,45 @@ TEST(EncapsulateSubgraphsTest, FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, shape2.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + 
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + NameAttrList shape_inference_graph2; + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1836,6 +1962,18 @@ TEST(EncapsulateSubgraphsTest, {{"H"}, "UnaryTest", {"outside_compilation_O2_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, @@ -1843,12 +1981,14 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ 
-1856,30 +1996,39 @@ TEST(EncapsulateSubgraphsTest, GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); - Node* g = Unary(recv, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts()); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, 
ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1925,19 +2074,24 @@ TEST(EncapsulateSubgraphsTest, { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1945,6 +2099,18 @@ TEST(EncapsulateSubgraphsTest, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", NameAttrList()}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O1_host_compute"}, 
"XlaHostCompute", {"D:o:0"}, @@ -1952,12 +2118,14 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1968,27 +2136,33 @@ TEST(EncapsulateSubgraphsTest, Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = Unary(recv, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - /*Node* g =*/Unary(a, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + /*Node* g =*/Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + 
.WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, recv2, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2039,19 +2213,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, {{{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, @@ -2063,10 +2242,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", 
absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2074,9 +2254,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", @@ -2085,11 +2267,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O3"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O3"}}, + {"_outside_compilation_subgraph", "O3"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {}}}, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -2100,23 +2285,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, 
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(recv1, b2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(recv2, b2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2") .WithControlInput(e)); - Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", - {DT_FLOAT}, b2.opts()); + Node* recv3 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); /*Node* i =*/Binary(recv3, e, b2.opts() .WithName("I") @@ -2131,7 +2320,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("J")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2167,14 +2356,46 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + 
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -2183,15 +2404,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), 
"F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); - node_builder1.Input(a).Input(b); + node_builder1.Input(a).Input(b).ControlInput(s); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2236,20 +2468,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape.opts()); - Node* a = InputShaped(shape.opts().WithName("A")); - Node* c = Unary(a, shape.opts().WithName("C")); - Node* e = BinaryUnknownShape(c, recv, + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -2262,15 +2496,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { 
{"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"c:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"c_0_arg", "c:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"c"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -2285,16 +2520,18 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0), + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -2303,9 +2540,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(b).Input(c); Node* call = - b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder); - 
Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 1f4b9c90a4ff0b1166cdb7b5942771b350740ef3..2264806d6bdabd9f26d9f83b681524399f996317 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -62,517 +62,6 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) { n->AddAttr(attr_name, value); } -// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges to remove. We should not remove the edge while iterating. - std::vector edges_to_remove; - for (const Edge* e : g->edges()) { - if (!e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - - if (!src_xla_computation && !dst_xla_computation) { - continue; - } else if (src_xla_computation && !dst_xla_computation) { - if (src_outside_compilation) { - // Case 1c: outside compilation to host computation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else if (!src_xla_computation && dst_xla_computation) { - if (dst_outside_compilation) { - // Case 1c: host computation control to outside compilation edge. 
- edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else { // src_xla_computation && dst_xla_computation - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation && dst_outside_compilation) { - // Case 1b: outside compilation to outside compilation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1a: outside compilation to another XLA computaition control - // edge. - TF_RETURN_IF_ERROR(AppendToListAttr( - e->src(), kXlaConnectedToOtherXlaComputationAttrName, - *dst_xla_computation)); - } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1a: another XLA computaition to outside compilation control - // edge. - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, - *src_xla_computation)); - } - } - } - } - - for (auto e : edges_to_remove) { - g->RemoveEdge(e); - } - return Status::OK(); -} - -// Step 2 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessXlaToXlaDataEdges(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between XLA computations. Notice that we do not store `Edge*` - // directly because we remove some nodes while adding Identity nodes, and - // those Edge pointers might be invalidated. 
- struct EdgeInfo { - int dst_input, dst_node_id; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (!src_xla_computation || !dst_xla_computation) { - continue; - } - - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation || dst_outside_compilation) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); - VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); - } - } - } - - // For each XLA -> XLA edge, add an Identity node between src and dst. - for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Create Identity node, and connect it between `src` and `dst`. - string identity_node_name = - absl::StrCat("bridge_", src->name(), "_", dst->name()); - DataType dtype = src->output_type(src_output); - TF_ASSIGN_OR_RETURN(Node * identity_node, - BuildIdentityNode(g, identity_node_name, dtype, src, - /*requested_device=*/absl::nullopt)); - identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name()); - g->AddEdge(src, src_output, identity_node, 0); - g->AddEdge(identity_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = identity_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. 
Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 3 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between outside compilation and host computation. Notice that - // we do not store `Edge*` directly because we remove some nodes while adding - // Identity nodes, and those Edge pointers might be invalidated. - struct EdgeInfo { - int dst_input, dst_node_id; - bool is_host_to_outside_compilation; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr && - e->dst()->attrs().Find(xla_computation_attr_name) != nullptr && - e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/true}); - VLOG(4) << "Host -> oc edge: " << e->DebugString(); - } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr && - e->src()->attrs().Find(xla_computation_attr_name) != nullptr && - e->src()->attrs().Find(outside_compilation_attr_name) != - nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/false}); - VLOG(4) << "Oc -> host edge: " << e->DebugString(); - } - } - - // Remove the edge from host to outside compilation. Add a placeholder as - // outside compilation node input. 
- std::map, Node*> placeholders; - for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Find or create placeholder node. - string new_name = - edges[i].is_host_to_outside_compilation - ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output) - : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output); - auto placeholder_index = std::make_pair(src->name(), src_output); - auto iter = placeholders.find(placeholder_index); - Node* placeholder_node; - if (iter == placeholders.end()) { - NodeDefBuilder placeholder_builder(new_name, "Placeholder"); - placeholder_builder.Attr("dtype", src->output_type(src_output)); - if (edges[i].is_host_to_outside_compilation) { - placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName, - src_output); - // If this placeholder node is in outside compilation, we need to set - // `xla_computation_attr_name` and `outside_compilation_attr_name`. 
- string xla_computation_attr, outside_compilation_attr; - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name, - &xla_computation_attr)); - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), - outside_compilation_attr_name, - &outside_compilation_attr)); - placeholder_builder.Attr(xla_computation_attr_name, - xla_computation_attr); - placeholder_builder.Attr(outside_compilation_attr_name, - outside_compilation_attr); - } else { - placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName, - src_output); - } - NodeDef placeholder_def; - TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); - Status s; - placeholder_node = g->AddNode(placeholder_def, &s); - TF_RETURN_IF_ERROR(s); - placeholders[placeholder_index] = placeholder_node; - } else { - placeholder_node = iter->second; - } - g->AddEdge(placeholder_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = placeholder_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 1 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) { - // Gather all outside compilation to host computation nodes. 
- struct PlaceHolderNodeInfo { - Node* n; - bool is_host_to_oc; - }; - std::vector placeholder_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Placeholder") { - if (HasNodeAttr(n->def(), - kOutsideCompilationToHostOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, false}); - } else if (HasNodeAttr(n->def(), - kHostToOutsideCompilationOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, true}); - } - } - } - - // Remove the placeholder nodes, and reconnect original edge. - auto node_name_index = g->BuildNodeNameIndex(); - for (auto placeholder_iter : placeholder_nodes) { - Node* n = placeholder_iter.n; - - string node_name; - int node_src_output; - if (placeholder_iter.is_host_to_oc) { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, - &node_src_output)); - } else { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, - &node_src_output)); - } - auto iter = node_name_index.find(node_name); - if (iter == node_name_index.end()) { - return errors::Internal( - "Cannot find original node for oc -> host placeholder node ", - node_name); - } - - // Change all usage node to use the original node instead. 
- Node* original_node = iter->second; - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(original_node, e->dst()); - g->RemoveEdge(e); - } - for (int i = 0; i < data_edges.size(); i++) { - Node* dst = data_edges[i].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[i].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(original_node->name(), ":", node_src_output); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(original_node, node_src_output, replace_node, dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int j = i + 1; j < data_edges.size(); j++) { - if (data_edges[j].dst == dst) { - data_edges[j].dst = replace_node; - } - } - - // Other placeholder node might have `dst` as original node. Update - // `node_name_index` with `replace_node`. - node_name_index[replace_node->name()] = replace_node; - } - - // Remove placeholder node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 2 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) { - // Gather Identity nodes to remove. - std::vector bridge_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Identity" && - HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) { - bridge_nodes.push_back(n); - } - } - - // Remove the identity nodes, and reconnect the original edge. 
- for (int i = 0; i < bridge_nodes.size(); i++) { - Node* n = bridge_nodes[i]; - const Edge* src_edge = nullptr; - TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge)); - - // Change all usage node to use the original node instead. - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(src_edge->src(), e->dst()); - g->RemoveEdge(e); - } - for (int j = 0; j < data_edges.size(); j++) { - Node* dst = data_edges[j].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[j].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output()); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node, - dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int k = j + 1; k < data_edges.size(); k++) { - if (data_edges[k].dst == dst) { - data_edges[k].dst = replace_node; - } - } - - // The node we replaced might be in `bridge_nodes`. If so, update - // `bridge_nodes` to use the replaced node. - for (int k = i + 1; k < bridge_nodes.size(); k++) { - if (bridge_nodes[k] == dst) { - bridge_nodes[k] = replace_node; - } - } - } - - // Remove Identity node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 3 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -// We do not need to worry about removed nodes in step 1 and 2; -// `PreprocessForEncapsulation` will not record control dependencies for those -// remvoed nodes in the first place. 
-Status AddControlDependencies( - Graph* g, const std::unordered_map& cluster_node_names) { - auto node_name_index = g->BuildNodeNameIndex(); - - // Reconnect outside compilation to outside compilation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaControlDependenciesAttrName); - for (const string& control_input : control_deps) { - auto iter = node_name_index.find(control_input); - if (iter == node_name_index.end()) { - return errors::Internal("Cannot find original node for ", - control_input); - } - g->AddControlEdge(iter->second, n); - } - } - } - - // Reconnect outside compilation to XLA computation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = GetNodeAttr( - n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(n, iter2->second); - } - } - } - - // Reconnect XLA computation to outside compilation control edge. 
- for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName, - &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(iter2->second, n); - } - } - } - - return Status::OK(); -} - // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. Status PreprocessControlEdgesBetweenOutsideCompilations( @@ -811,20 +300,6 @@ Status PostprocessControlEdgesBetweenOutsideCompilations( const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; -const char kXlaConnectedToOtherXlaComputationAttrName[] = - "_xla_connected_to_other_xla_computation"; -const char kXlaConnectedFromOtherXlaComputationAttrName[] = - "_xla_connected_from_other_xla_computation"; -const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies"; -const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src"; -const char kOutsideCompilationToHostOriginalNodeAttrName[] = - "_xla_oc_to_host_node_name"; -const char kOutsideCompilationToHostSrcOutputAttrName[] = - "_xla_oc_to_host_src_output"; -const char kHostToOutsideCompilationOriginalNodeAttrName[] = - "_xla_host_to_oc_node_name"; -const char kHostToOutsideCompilationSrcOutputAttrName[] = - "_xla_host_to_oc_src_output"; const char kXlaConnectedToXlaComputationAttrName[] = "_xla_connected_to_xla_computation"; const char 
kXlaConnectedFromXlaComputationAttrName[] = @@ -835,32 +310,7 @@ const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; const char kXlaControlDependenciesWithinXlaClusterAttrName[] = "_xla_control_dependencies_within_xla_cluster"; -Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Find all outside compilation to XLA computation data edges. - std::unordered_set outside_compilation_send_nodes; - for (auto e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name); - if (!src_computation || !dst_computation || - *src_computation != *dst_computation) { - continue; - } - - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (src_outside_compilation && !dst_outside_compilation) { - outside_compilation_send_nodes.insert(e->src()); - } - } - +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { // Perform shape inference. std::map arg_shapes; GraphShapeInfo shape_info; @@ -868,55 +318,21 @@ Status PerformStaticShapeInferenceBeforeEncapsulation( InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); // Add attribute for output shapes. 
- for (Node* n : outside_compilation_send_nodes) { - auto iter = shape_info.find(n->name()); - if (iter == shape_info.end()) { - continue; - } - + auto node_name_index = g->BuildNodeNameIndex(); + for (auto iter : shape_info) { std::vector output_shapes; - std::transform(iter->second.begin(), iter->second.end(), + std::transform(iter.second.begin(), iter.second.end(), std::back_inserter(output_shapes), [](const InferredShape& inferred_shape) { return inferred_shape.shape; }); + Node* n = node_name_index[iter.first]; n->AddAttr(kXlaInferredShapesAttrName, output_shapes); } return Status::OK(); } -Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - g, xla_computation_attr_name, outside_compilation_attr_name)); - return Status::OK(); -} - -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters) { - // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2, - // but the node name won't change. Record cluster node name for - // `AddControlDependencies`. 
- std::unordered_map cluster_node_names; - for (const auto& iter : clusters) { - cluster_node_names[iter.first] = iter.second.node->name(); - } - - TF_RETURN_IF_ERROR( - RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g)); - TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g)); - TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names)); - return Status::OK(); -} - Status PreprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index e363bc5754ac395bae262dc67a780a0173efaf5e..c9f16d14168163e11bb19092f566f1de8724aca3 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -27,51 +27,13 @@ namespace tensorflow { // a list of PartialTensorShape objects. extern const char kXlaInferredShapesAttrName[]; -// Infer output shapes for outside compilation nodes which have output data -// edges to XLA computation nodes. These shapes will be used later by XLA -// compiler as output shapes of the outside compilation's XlaHostCompute op. -// XLA computation nodes will be mark by attr `xla_computation_attr_name`; -// outside compilation nodes will be marked by both attr -// `xla_computation_attr_name` and `outside_compilation_attr_name`. -// -// Those outside compilation nodes will be marked with attribute -// `kXlaInferredShapesAttrName`. +// Infers output shapes for all nodes in graph `g`. The output shapes will be +// stored in node attribute `kXlaInferredShapesAttrName`. // // We have to perform shape inference before encapsulation because after // encapsulation, some nodes will be encapsulated into function call, and shape // inference does not handle function call at the moment. 
-Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - -// Attribute indicating that some ops in other XLA computation has control -// dependency on this node. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedToOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependency on some ops in -// other XLA computation. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedFromOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependencies on some other -// nodes. Attribute value will be a list of string (node names). -extern const char kXlaControlDependenciesAttrName[]; - -// Attribute indicating that this is an Identity node added to act as a bridge -// between different XLA computations. Attribute value will be string (source -// node name). -extern const char kBridgeSourceNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// string (original input node name). -extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// int (src_output for original edge). -extern const char kOutsideCompilationToHostSrcOutputAttrName[]; +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); // Attribute indicating that some ops in this node's XLA computation has control // dependency on this node. Attribute value will always be "true". @@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[]; // this node's XLA computation. Attribute value will always be "true". 
extern const char kXlaConnectedFromXlaComputationAttrName[]; -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an host node. Attribute value will be string -// (original input node name). -extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for a host node. Attribute value will be int (src_output -// for original edge). -extern const char kHostToOutsideCompilationSrcOutputAttrName[]; - // Attribute indicating that this is an Placeholder node added to act as a // temporary input node for an outside compilation node. Attribute value will be // string (original input node name). @@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[]; // (node names). extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; -// Preprocesses edges between different XLA clusters for encapsulation. It will -// perform the following operations in order: -// -// 1a. For control edges between outside compilation and another XLA -// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName -// = XLA computation node name" to the outside compilation node. -// 1b. For control edges between different outside compilations (in different -// XLA computations), remove the edge and add attr -// "kXlaControlDependenciesAttrName = src node name" to dst node. -// 1c. For control edges between outside compilation and host computation, -// remove the edge and add attr "kXlaControlDependenciesAttrName = src node -// name" to dst node. -// 2. For data edges between different XLA computations, if either src or dst -// is outside compilation, add an Identity node in between the edge. The -// identity node will have attr kBridgeSourceNodeAttrName. -// 3. For data edges between outside compilation and host computation, remove -// the edge and create a Placeholder node as dst node's input. 
-Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - // Information for XLA computation. struct XlaClusterInfo { // Add an explicitly-defined default constructor for this class. @@ -158,24 +89,6 @@ struct XlaClusterInfo { const std::map host_compute_core; }; -// Postprocesses edges between different XLA clusters for encapsulation. This -// function reverts what `PreprocessForEncapsulation` did. It will perform the -// following operations in order: -// -// 1. Remove Placeholder nodes between outside compilation and host computation -// (created in `PreprocessForEncapsulation` step 3). -// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. -// 3a. Reconnect control edges between outside compilation and another XLA -// computation (marked by `PreprocessForEncapsulation` step 1a). -// 3b. Reconnect control edges between different outside compilations (marked by -// `PreprocessForEncapsulation` step 1b). -// 3c. Reconnect control edges between outside compilation and host computation -// (marked by `PreprocessForEncapsulation` step 1c). -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters); - // Preprocesses edges within the same XLA cluster. It will perform the following // operations in order: // diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 3b8b49cb92f3e453883a8e64e12ce3748a5173f6..6d1661222e3eaf9df4f9f91f2b426c80b55245b2 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -38,24 +37,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { Graph g(OpRegistry::Global()); TF_CHECK_OK(s.ToGraph(&g)); - // "add" node is outside compilation node, "identity" node is XLA node. - auto node_index = g.BuildNodeNameIndex(); - Node *add_node = node_index["add"], *identity_node = node_index["identity"]; - add_node->AddAttr("_xla", "cluster"); - add_node->AddAttr("_oc", "cluster"); - identity_node->AddAttr("_xla", "cluster"); - TF_CHECK_OK( - PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc")); + TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g)); - // Check that only "add" node now has _xla_inferred_shapes attr. - std::vector nodes_with_inferred_shape; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) { - nodes_with_inferred_shape.push_back(n); - } - } - EXPECT_EQ(nodes_with_inferred_shape.size(), 1); - EXPECT_EQ(nodes_with_inferred_shape[0], add_node); + // Check that "add" node now has _xla_inferred_shapes attr. 
+ auto node_index = g.BuildNodeNameIndex(); + Node *add_node = node_index["add"]; std::vector output_shapes; TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName, &output_shapes)); @@ -66,329 +52,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { EXPECT_EQ(shape_proto.dim(0).size(), 2); } -TEST(PreprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "add" = "const_0" + "const_1" in XLA computation 0 - // "identity0" = "add" in XLA computation 0 & outside compilation 0 - // "identity1" = "identity0" in XLA computation 0 - // "identity2" = "identity1" in host computation - // "identity3" = "identity2" in XLA computation 1 - // "identity4" = "identity3" in XLA computation 1 & outside compilation 1 - // "identity5" = "identity4" in XLA computation 1 - // "identity6" = "identity5" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - Output add = ops::Add(s.WithOpName("add"), const_0, const_1); - Output identity0 = ops::Identity(s.WithOpName("identity0"), add); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3); - Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. 
- Node *const0_node = node_index["const_0"], *add_node = node_index["add"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"], - *identity4_node = node_index["identity4"], - *identity5_node = node_index["identity5"]; - add_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "0"); - identity3_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_oc", "0"); - identity5_node->AddAttr("_xla", "1"); - // Case 1a: control edges between outside compilation and another XLA - // computation. - g.AddControlEdge(identity0_node, identity3_node); - g.AddControlEdge(identity1_node, identity4_node); - // Case 1b: control edges between different outside compilations. - g.AddControlEdge(identity0_node, identity4_node); - // Case 1c: control edges between outside compilation and host computation. - g.AddControlEdge(const0_node, identity0_node); - g.AddControlEdge(identity0_node, identity2_node); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" - // to the outside compilation node. - std::vector attr; - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaConnectedToOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "1"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaConnectedFromOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "0"); - // Case 1b: add attr "_xla_control_deps = src node name" to dst node. 
- attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); - // Case 1c: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "const_0"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity2_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); -} - -TEST(PreprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "identityn0" = ("const_0", "const_1") in host computation 0 - // "add0" = "const_0" + "const_1" in XLA computation 0 - // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 - // "identity0" = "add1" in XLA computation 0 - // "add2" = "add1" + "identity0" in host computation - // "add3" = "add1" + "add2" in XLA computation 1 - // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0 - // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 & - // outside compilation 0 - // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 & - // outside compilation 0 - // "identity1" = "add4" in XLA computation 1 - // "identity2" = "identity1" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - auto identityn0 = - ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1}); - Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); - Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); - Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); - Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); - Output add3 
= ops::Add(s.WithOpName("add3"), add1, add2); - Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); - Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]); - auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"), - {identityn0[0], identityn0[1]}); - Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); - Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr. - Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], - *identity0_node = node_index["identity0"], - *add3_node = node_index["add3"], *add4_node = node_index["add4"], - *add5_node = node_index["add5"], - *identityn1_node = node_index["identityn_1"], - *identity1_node = node_index["identity1"]; - add0_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_oc", "0"); - identity0_node->AddAttr("_xla", "0"); - add3_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_oc", "0"); - add5_node->AddAttr("_xla", "1"); - add5_node->AddAttr("_oc", "0"); - identityn1_node->AddAttr("_xla", "1"); - identityn1_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "1"); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Check input nodes for related data edges. - node_index = g.BuildNodeNameIndex(); - // Step 2: add an Identity node between different XLA computations. - Node *bridge_add1_add3 = node_index["bridge_add1_add3"]; - EXPECT_NE(bridge_add1_add3, nullptr); - string str; - TF_CHECK_OK( - GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"]; - EXPECT_NE(bridge_identity0_add4, nullptr); - // Step 3: add placeholder for edges between host computation and outside - // compilation. 
- EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0"); - Node *add1_oc_to_host_placeholder = - node_index["add1_oc_to_host_placeholder_0"]; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - int i; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - add4_node = node_index["add4"]; - ASSERT_NE(add4_node, nullptr); - EXPECT_EQ(add4_node->def().input(0), - "bridge_identity0_add4_host_to_oc_placeholder_0"); - Node *identity0_host_to_oc_placeholder = - node_index["bridge_identity0_add4_host_to_oc_placeholder_0"]; - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "bridge_identity0_add4"); - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - - // Check different placeholder nodes are created for different src_output. - Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"], - *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"]; - EXPECT_NE(placeholder0, nullptr); - EXPECT_NE(placeholder1, nullptr); - // Check we only have 2 placeholder nodes created for "identityn_0". 
- int placeholder_count = 0; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) { - string attr; - TF_CHECK_OK(GetNodeAttr( - n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr)); - if (attr == "identityn_0") { - ++placeholder_count; - } - } - } - EXPECT_EQ(placeholder_count, 2); -} - -TEST(PostprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const0" - // "identity0" = "const0" (XLA computation 0) - // "identity1" = "identity0" - // "identity2" = "identity1" (XLA computation 1) - // "identity3" = "identity2" - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output identity0 = ops::Identity(s.WithOpName("identity0"), const0); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. - Node *const0_node = node_index["const0"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"]; - identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName, - std::vector{"0"}); - identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName, - std::vector{"1"}); - identity3_node->AddAttr(kXlaControlDependenciesAttrName, - std::vector{"const0", "identity1"}); - - std::unordered_map clusters; - clusters["0"].node = identity0_node; - clusters["1"].node = identity2_node; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Case 3a: we have control edge identity0 -> identity1, and identity1 -> - // identity2. 
- bool edge_identity0_identity1 = false, edge_identity1_identity2 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == identity0_node && e->dst() == identity1_node) { - edge_identity0_identity1 = true; - } else if (e->src() == identity1_node && e->dst() == identity2_node) { - edge_identity1_identity2 = true; - } - } - EXPECT_TRUE(edge_identity0_identity1); - EXPECT_TRUE(edge_identity1_identity2); - // Case 3b: we have control edge const0 -> identity3, and identity1 -> - // identity3. - bool edge_const0_identity3 = false, edge_identity1_identity3 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == const0_node && e->dst() == identity3_node) { - edge_const0_identity3 = true; - } else if (e->src() == identity1_node && e->dst() == identity3_node) { - edge_identity1_identity3 = true; - } - } - EXPECT_TRUE(edge_const0_identity3); - EXPECT_TRUE(edge_identity1_identity3); -} - -TEST(PostprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const0" in outside compilation "0" - // "placeholder0" (for "const0") in host computation - // "add0" = "placeholder0" + "placeholder0" in host computation - // "placeholder1" (for "add0") in outside compilation 1 - // "add1" = "placeholder1" + "placeholder1" in outside compilation 1 - // - // "bridge" = "placeholder0" in host computation - // "placeholder2" (for "bridge") in outside compilation 1 - // "add2" = "placeholder2" + "placeholder2" in outside compilation 1 - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output placeholder0 = - ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32); - Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0); - Output placeholder1 = - ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32); - Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1); - Output bridge 
= ops::Identity(s.WithOpName("bridge"), placeholder0); - Output placeholder2 = - ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32); - Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set related attributes. - Node *placeholder0_node = node_index["placeholder0"]; - placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName, - "const0"); - placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0); - Node *placeholder1_node = node_index["placeholder1"]; - placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "add0"); - placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - Node *bridge_node = node_index["bridge"]; - bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0"); - Node *placeholder2_node = node_index["placeholder2"]; - placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "bridge"); - placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - - std::unordered_map clusters; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Result graph should be: - // "add0" = "const0" + "const0" - // "add1" = "add0" + "add0" - // "add2" = "const0" + "const0" - node_index = g.BuildNodeNameIndex(); - EXPECT_EQ(node_index.size(), 6); - EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0"); - EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0"); - EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0"); - EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0"); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 
d334100aa4a915a87fb05d371e0e3379a7ee05f2..f0c9d573451952a398dce190e102a33270a4d739 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -15,13 +15,17 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -36,6 +40,25 @@ namespace { const char* const kXlaClusterOutput = "XlaClusterOutput"; +bool IsCpuGpuCompile(const Graph* graph) { + for (Node* n : graph->nodes()) { + string name; + // Only consider nodes being compiled. + if (!GetNodeAttr(n->attrs(), + EncapsulateXlaComputationsPass::kXlaClusterAttr, &name) + .ok()) + continue; + // Early return for any node with a device that is not a CPU or GPU. + DeviceNameUtils::ParsedName parsed; + if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) { + if (parsed.type != DEVICE_CPU && parsed.type != DEVICE_GPU) { + return false; + } + } + } + return true; +} + // Checks if a graph node is marked to be a guaranteed constant. bool is_guaranteed_constant(const Node& n) { bool guaranteed_constant = false; @@ -173,10 +196,11 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // Nondeterminism in serialization would not lead to incorrect results, but // may cause spurious cache misses. 
DeterministicSerialization is a // best-effort deterministic serialization. - string serialized; - TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); - uint64 fingerprint = Fingerprint64(serialized); - LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + const size_t size = gdef.ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK(SerializeToBufferDeterministic(gdef, serialized.get(), size)); + uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); + VLOG(1) << "Subgraph fingerprint:" << fingerprint; call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); return Status::OK(); } @@ -297,6 +321,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, NodeDef def; def.set_name(launch->name()); + MergeDebugInfo(NodeDebugInfo(launch->def()), &def); // Target the XLA CPU/GPU backends. VLOG(2) << "Replacing with XlaLaunch"; @@ -350,12 +375,19 @@ Status EncapsulateXlaComputationsPass::Run( << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", **options.graph, options.flib_def); - TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + const char* additional_help = + IsCpuGpuCompile(options.graph->get()) + ? 
xla::status_macros::kPossibleAutoJitAlternative + : ""; + + TF_RETURN_WITH_CONTEXT_IF_ERROR(Encapsulate(options.graph, options.flib_def), + additional_help); VLOG(1) << "EncapsulateXlaComputations() half-way: " << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", **options.graph, options.flib_def); - TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + TF_RETURN_WITH_CONTEXT_IF_ERROR(BuildXlaLaunchOps(options.graph->get()), + additional_help); VLOG(1) << "EncapsulateXlaComputations() finished: " << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", **options.graph, options.flib_def); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index e3c7e2f89be9b37b51a633dabb099969c181013f..2a770c527b2fae91352fd17dacb13495a3a73f34 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -20,14 +20,17 @@ limitations under the License. 
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { @@ -98,9 +101,12 @@ xla::StatusOr BuildRecvAtHostNode( recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. - recv_at_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + recv_at_host_builder.Attr("device_ordinal", device_ordinal_value); recv_at_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true); recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING); TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def)); Status s; @@ -197,9 +203,12 @@ xla::StatusOr BuildSendFromHostNode( send_from_host_builder.Attr("Tinputs", send_from_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. 
- send_from_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + send_from_host_builder.Attr("device_ordinal", device_ordinal_value); send_from_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true); std::vector inputs(send_from_host_dtypes.size()); for (auto* n : ret_nodes) { int index; @@ -300,6 +309,10 @@ xla::StatusOr BuildXlaHostComputeNodeDef( host_compute_builder.Attr("tpu_core", core); } + // Set input tokens. + host_compute_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + // Populate inputs. std::vector input_dtypes; TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes)); @@ -322,6 +335,38 @@ xla::StatusOr BuildXlaHostComputeNodeDef( return new_def; } +Status ValidateOutsideCompilationCallNode(Node* call_node) { + // DT_INT64 as input/output for outside compilation is not supported yet: + // b/120809951. + for (const Edge* e : call_node->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + DataType dtype = e->src()->output_type(e->src_output()); + if (dtype == DT_INT64) { + return errors::Unimplemented( + "int64 input for outside compilation is not supported yet: " + "b/120809951. Please cast output of node ", + e->src()->DebugString(), + " to int32 before feeding it into outside compilation."); + } + } + for (const Edge* e : call_node->out_edges()) { + if (e->IsControlEdge()) { + continue; + } + DataType dtype = e->dst()->input_type(e->dst_input()); + if (dtype == DT_INT64) { + return errors::Unimplemented( + "int64 output for outside compilation is not supported yet: " + "b/120809951. Please cast input of node ", + e->dst()->DebugString(), + " to int32 before returning it from outside compilation."); + } + } + return Status::OK(); +} + // Replace outside compilation function call node with XlaHostCompute node. 
// If the function call node has no input/output edges, we will just remove it // and not create a XlaHostCompute node. @@ -357,6 +402,51 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( return Status::OK(); } +// Resets "device_ordinal" attr to placeholder value for related nodes +// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes +// containing XlaRecvAtHost/XlaSendFromHost). +Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + for (Node* n : g->nodes()) { + if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { + continue; + } + + if (n->type_string() == "_XlaRecvAtHost" || + n->type_string() == "_XlaSendFromHost") { + n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); + } else if (n->type_string() == "If") { + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else if (n->type_string() == "While") { + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else if (HasNodeAttr(n->def(), "device_ordinal")) { + // Function call node containing outside compilation. 
+ n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); + } else { + return errors::Internal("Unknown node marked with ", + kXlaHasHostTransferAttrName, ": ", + n->DebugString()); + } + } + return Status::OK(); +} + // For an XLA computation, builds host side graph given all outside compilation // graphs inside it. The host side graph contains: // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and @@ -368,8 +458,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( Status ConstructHostGraph( const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { - host_graph->reset(new Graph(fld)); + FunctionLibraryDefinition* fld, const string& host_graph_func_name) { + Graph host_graph(fld); // Create sequencer node in host graph. NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"), @@ -378,24 +468,34 @@ Status ConstructHostGraph( NodeDef sequencer_def; TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def)); Status s; - Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s); + Node* sequencer = host_graph.AddNode(sequencer_def, &s); TF_RETURN_IF_ERROR(s); // Create key placeholder in host graph. TF_ASSIGN_OR_RETURN( Node * key_placeholder, - AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get())); + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); // For each outside compilation graph, copy them to host graph with the // following changes: // a) Use key_placeholder in host graph instead of its own. - // b) Add control edge from RecvAtHost/SendFromHost to sequencer. + // b) Add control edge from host transfer nodes (XlaRecvAtHost, + // XlaSendFromHost, If/While nodes containing + // XlaRecvAtHost/XlaSendFromHost) to sequencer node. // c) Clear node_def.device(), so device placer won't get confused. 
for (const string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after we expanded all host graphs. We cannot just use placeholder + // value here because FunctionDef instantiation does not allow placeholder + // value for attributes. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* host_fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(host_func), AttrSlice(), fld, + *fld->Find(host_func), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -408,8 +508,8 @@ Status ConstructHostGraph( FixupSourceAndSinkEdges(host_fbody->graph); std::map node_map; - node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); - node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); + node_map[host_fbody->graph->source_node()] = host_graph.source_node(); + node_map[host_fbody->graph->sink_node()] = host_graph.sink_node(); Status s; ReverseDFS( *host_fbody->graph, /*enter=*/nullptr, @@ -431,7 +531,7 @@ Status ConstructHostGraph( NodeDef copy_def = n->def(); // Change c). copy_def.clear_device(); - copy = (*host_graph)->AddNode(copy_def, &s); + copy = host_graph.AddNode(copy_def, &s); if (!s.ok()) { return; } @@ -446,22 +546,23 @@ Status ConstructHostGraph( e->src()->DebugString()); return; } - (*host_graph) - ->AddEdge(node_map[e->src()], e->src_output(), copy, - e->dst_input()); + host_graph.AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); } // Change b). 
- if (copy->type_string() == "_XlaRecvAtHost" || - copy->type_string() == "_XlaSendFromHost") { - (*host_graph)->AddControlEdge(copy, sequencer); + if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) { + host_graph.AddControlEdge(copy, sequencer); } }, NodeComparatorID()); + if (!s.ok()) { return s; } } + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph)); // sequencer and key_placeholder might be dead nodes. Prune them if necessary. // - sequencer should be pruned iff it has no input control edges from @@ -470,21 +571,30 @@ Status ConstructHostGraph( // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost. // We don't need to do anything special. if (!sequencer->in_edges().empty()) { - (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node()); + host_graph.AddControlEdge(sequencer, host_graph.sink_node()); } PruneForReverseReachability( - host_graph->get(), - std::unordered_set{(*host_graph)->sink_node()}); + &host_graph, std::unordered_set{host_graph.sink_node()}); // Postprocess edges between different outside compilations. TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( - host_graph->get(), outside_compilation_attr_name)); + &host_graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_host_graph_for_", xla_cluster_name), - **host_graph, fld); + host_graph, fld); + } + + FunctionDef host_graph_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef)); } return Status::OK(); @@ -492,8 +602,28 @@ Status ConstructHostGraph( // Expand XLA computation's outside compilation host side graph into main graph. 
// Add a control edge between sequencer node and the XLA computation node. -Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, +Status ExpandHostGraphIntoMainGraph(Graph* main_graph, + FunctionLibraryDefinition* fld, + const string& host_graph_func_name, Node* xla_computation_node) { + // Temporarily use "0" as "device_ordinal". It will be rewritten with the + // correct value in a later pass. We cannot just use placeholder value here + // because FunctionDef instantiation does not allow placeholder value for + // attributes. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* host_graph = fbody->graph; + // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse // reachable from sink node so all nodes will be copied. // TODO(b/77601805): consolidate copy graph functions. @@ -545,23 +675,25 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, return s; } -// Rewrites shape inference graph for outside compilation. -// 1. If the outside compilation is a "top-level" one (not in a function of any -// If/While/etc.), this shape inference graph might have host computation to -// outside compilation placeholder nodes, which will cause shape inference to -// fail. However, those nodes are not in `host_graph` any more (because we -// have executed `PostprocessForEncapsultion`). In this case, we clear the -// graph, and copy SendFromHost with all its predecessors from `host_graph`. -// This case is detected by whether the SendFromHost node exists in -// `host_graph` as well. -// 2. 
Remove control edges, and prune nodes that are not useful for shape -// inference. +// Rewrites shape inference graph for outside compilation: +// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from +// `host_graph`. Because we might still have outside compilation to outside +// compilation placeholder nodes in shape inference graph, which will prevent +// us from inferring XlaSendFromHost shape. But in `host_graph`, we already +// removed those placeholder nodes. +// 2) Remove control edges. +// 3) Prune nodes that are not useful for shape inference. Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, Graph* host_graph, FunctionLibraryDefinition* fld) { + // Use "0" as "device_ordinal". It does not matter for shape inference. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(shape_inference_graph_name), AttrSlice(), fld, + *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -650,6 +782,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, g->RemoveEdge(e); } } + // Nodes that are not reverse reachable from SendFromHost are not useful for // shape inference. Prune them. PruneForReverseReachability(g, @@ -669,6 +802,681 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, return Status::OK(); } +// Builds XlaSendToHost node which sends cond predicate to host. 
+xla::StatusOr BuildSendIfPredNode(const string& name, + const string& host_transfer_key, + Node* pred_node, Graph* g) { + NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); + send_pred_builder.Attr("Tinput", DT_BOOL); + send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); + send_pred_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); + NodeDef send_pred_def; + TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def)); + Status s; + Node* send_pred_node = g->AddNode(send_pred_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(pred_node, 0, send_pred_node, 0); + return send_pred_node; +} + +// Replaces key placeholder node with an _Arg node. +Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, + const string& func_name, + FunctionLibraryDefinition* fld) { + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after rewriting. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find or create the key placeholder node. + Node* key_placeholder = nullptr; + for (Node* n : g->nodes()) { + if (IsKeyPlaceholderNode(*n)) { + key_placeholder = n; + break; + } + } + if (!key_placeholder) { + TF_ASSIGN_OR_RETURN(key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, g)); + } + + // Build the _Arg node, and replace key placeholder node with it. 
+ NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp); + arg_builder.Attr("T", DT_STRING); + arg_builder.Attr("index", 0); + NodeDef arg_def; + TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def)); + TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status()); + + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g)); + + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef)); + return Status::OK(); +} + +// Builds host side graph for If node. +Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, + const string& xla_cluster_name, + const string& if_node_name, + const string& host_transfer_key, + const string& host_graph_func_name, + FunctionLibraryDefinition* fld, + const string& then_branch_host_func_name, + const string& else_branch_host_func_name) { + Graph host_graph(fld); + string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: build XlaRecvAtHost node to recv predicate. 
+ NodeDefBuilder recv_pred_builder( + absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost"); + recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL}); + recv_pred_builder.Attr("key", host_transfer_key); + recv_pred_builder.Attr("device_ordinal", device_ordinal_value); + recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + recv_pred_builder.Attr(outside_compilation_attr_name, + outside_compilation_name); + recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true); + recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING); + NodeDef recv_pred_def; + TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); + Status s; + Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0); + + // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key + // placeholder with an _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, then_branch_host_func_name, fld)); + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, else_branch_host_func_name, fld)); + + // Step 4: build If node to choose between `{then, else}_branch_host_graph`. 
+ NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If"); + if_builder.Attr("Tcond", DT_BOOL); + if_builder.Attr("Tin", std::vector{DT_STRING}); + if_builder.Attr("Tout", std::vector{}); + NameAttrList host_then_branch, host_else_branch; + host_then_branch.set_name(then_branch_host_func_name); + (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; + host_else_branch.set_name(else_branch_host_func_name); + (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; + if_builder.Attr("then_branch", host_then_branch); + if_builder.Attr("else_branch", host_else_branch); + if_builder.Attr(kXlaHasHostTransferAttrName, true); + if_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + if_builder.Attr(outside_compilation_attr_name, outside_compilation_name); + if_builder.Input(recv_pred_node->name(), 0, DT_BOOL); + std::vector if_inputs{ + {key_placeholder->name(), 0, DT_STRING}}; + if_builder.Input(if_inputs); + NodeDef if_def; + TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def)); + Node* if_node = host_graph.AddNode(if_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(recv_pred_node, 0, if_node, 0); + host_graph.AddEdge(key_placeholder, 0, if_node, 1); + + // Convert `host_graph` to function, and add a "device_ordinal" attr. + FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +// Rewrites loop cond to add a node which sends loop cond to host. +Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld, + const NameAttrList& loop_cond_func, + const string& while_node_name, + const string& host_transfer_key) { + // Instantiate the loop cond function. 
+ FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find the _Retval node and the loop cond node. + Node* ret_node = nullptr; + for (Node* n : g->nodes()) { + if (n->type_string() == "_Retval") { + if (ret_node) { + return errors::Internal("Multiple return node for loop cond function ", + loop_cond_func.name(), ": ", + ret_node->DebugString(), " and ", + n->DebugString()); + } else { + ret_node = n; + } + } + } + if (!ret_node) { + return errors::Internal("No _Retval node for loop cond function ", + loop_cond_func.name()); + } + Node* loop_cond; + TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond)); + + // Build the XlaSendToHost node. + NodeDefBuilder send_loop_cond_builder( + absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost"); + send_loop_cond_builder.Attr("Tinput", DT_BOOL); + send_loop_cond_builder.Attr("key", + absl::StrCat(host_transfer_key, "_dtoh_0")); + send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL); + NodeDef send_loop_cond_def; + TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def)); + Status s; + Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(loop_cond, 0, send_loop_cond_node, 0); + + // Replace original function. + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef)); + + return Status::OK(); +} + +// Rewrites while loop cond function for host. 
+//
+// The host-side cond function is rewritten in place inside `fld`:
+//   1. its key placeholder is replaced with an _Arg node;
+//   2. an _XlaRecvAtHost node (fed by that _Arg) is added to receive the
+//      boolean loop predicate under `host_transfer_key`;
+//   3. a bool _Retval (index 0) is added so the received predicate becomes the
+//      function's return value;
+//   4. "device_ordinal" attrs are reset to their placeholder value and the
+//      function definition named `cond_host_func_name` is replaced.
+//
+// `xla_cluster_attr_name` / `xla_cluster_name` and
+// `outside_compilation_attr_name` / `outside_compilation_name` are attached as
+// attrs on the new _XlaRecvAtHost node. Returns an Internal error if the
+// function cannot be found in `fld` or contains no _Arg node.
+Status RewriteHostWhileLoopCond(
+    const string& cond_host_func_name, const string& while_node_name,
+    const string& host_transfer_key, const string& xla_cluster_attr_name,
+    const string& xla_cluster_name, const string& outside_compilation_attr_name,
+    const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
+  // Replace key placeholder node with _Arg node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, cond_host_func_name, fld));
+
+  // Look the function up explicitly so that a missing definition is reported
+  // as an error instead of dereferencing a null pointer.
+  const FunctionDef* cond_fdef = fld->Find(cond_host_func_name);
+  if (!cond_fdef) {
+    return errors::Internal("Cannot find function ", cond_host_func_name);
+  }
+
+  // Instantiate cond function. A temporary concrete "device_ordinal" (0) is
+  // supplied for instantiation; it is reset to the placeholder value below.
+  // NOTE(review): presumably needed because the function body refers to
+  // "device_ordinal" as a placeholder attr -- confirm.
+  AttrValue device_ordinal_temp_value;
+  device_ordinal_temp_value.set_i(0);
+  protobuf::Map<string, AttrValue> attrs;
+  attrs["device_ordinal"] = device_ordinal_temp_value;
+  FunctionBody* cond_fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *cond_fdef, AttrSlice(&attrs), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &cond_fbody));
+  std::unique_ptr<FunctionBody> cond_fbody_deleter(cond_fbody);
+  Graph* cond_graph = cond_fbody->graph;
+
+  // Locate the _Arg node that carries the host compute key. If the graph holds
+  // several _Arg nodes, the last one visited is used.
+  Node* key_arg = nullptr;
+  for (Node* n : cond_graph->nodes()) {
+    if (n->type_string() == "_Arg") {
+      key_arg = n;
+    }
+  }
+  if (!key_arg) {
+    return errors::Internal(
+        "No _Arg node found for host compute key in function ",
+        cond_host_func_name);
+  }
+
+  // Add an XlaRecvAtHost node to use as cond function return value.
+  // We don't need to set kXlaHasHostTransferAttrName for this node, because
+  // it's already added for the "While" node on the host.
+  NodeDefBuilder recv_pred_builder(
+      absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
+  recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
+  recv_pred_builder.Attr("key", host_transfer_key);
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+  recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
+  recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  recv_pred_builder.Attr(outside_compilation_attr_name,
+                         outside_compilation_name);
+  recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
+  NodeDef recv_pred_def;
+  TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
+  Status s;
+  Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
+
+  // Return the received predicate from the cond function.
+  NodeDefBuilder ret_builder(
+      absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
+  ret_builder.Attr("T", DT_BOOL);
+  ret_builder.Attr("index", 0);
+  ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
+  NodeDef ret_def;
+  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
+  Node* ret_node = cond_graph->AddNode(ret_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
+
+  // Reset device_ordinal to placeholder value.
+  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
+
+  // Replace original function.
+  FunctionDef cond_replace_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef));
+  TF_RETURN_IF_ERROR(
+      fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
+
+  return Status::OK();
+}
+
+// Rewrites while loop body function for host.
+Status RewriteHostWhileLoopBody( + const string& body_host_func_name, const string& while_node_name, + const string& host_transfer_key, const string& xla_cluster_attr_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, + const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + // Replace key placeholder node with _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, body_host_func_name, fld)); + + // Instantiate body function. + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_temp_value; + FunctionBody* body_fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &body_fbody)); + std::unique_ptr body_fbody_deleter(body_fbody); + Graph* body_graph = body_fbody->graph; + Node* key_arg = nullptr; + for (Node* n : body_graph->nodes()) { + if (n->type_string() == "_Arg") { + key_arg = n; + } + } + if (!key_arg) { + return errors::Internal( + "No _Arg node found for host compute key in function ", + body_host_func_name); + } + + // Add a _Retval node to loop body. + NodeDefBuilder ret_builder( + absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval"); + ret_builder.Attr("T", DT_STRING); + ret_builder.Attr("index", 0); + ret_builder.Input(key_arg->name(), 0, DT_STRING); + NodeDef ret_def; + TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); + Status s; + Node* ret_node = body_graph->AddNode(ret_def, &s); + TF_RETURN_IF_ERROR(s); + body_graph->AddEdge(key_arg, 0, ret_node, 0); + + // Reset device_ordinal to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph)); + + // Replace original function. 
+ FunctionDef body_replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(body_host_func_name, body_replace_fdef)); + + return Status::OK(); +} + +// Builds host side graph for while node. +Status BuildHostGraphForWhileNode( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const string& while_node_name, const string& host_transfer_key, + const string& host_graph_func_name, FunctionLibraryDefinition* fld, + const string& cond_host_func_name, const string& body_host_func_name) { + Graph host_graph(fld); + string outside_compilation_name = absl::StrCat("oc_while_", while_node_name); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: rewrite cond function. + TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond( + cond_host_func_name, while_node_name, host_transfer_key, + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + outside_compilation_name, fld)); + + // Step 3: rewrite body function. + TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody( + body_host_func_name, while_node_name, host_transfer_key, + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + outside_compilation_name, fld)); + + // Step 4: build While node. 
+ NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name), + "While"); + while_builder.Attr("T", std::vector{DT_STRING}); + NameAttrList func; + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + (*func.mutable_attr())["device_ordinal"] = device_ordinal_value; + func.set_name(cond_host_func_name); + while_builder.Attr("cond", func); + func.set_name(body_host_func_name); + while_builder.Attr("body", func); + while_builder.Attr(kXlaHasHostTransferAttrName, true); + while_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + while_builder.Attr(outside_compilation_attr_name, outside_compilation_name); + std::vector while_inputs{ + {key_placeholder->name(), 0, DT_STRING}}; + while_builder.Input(while_inputs); + NodeDef while_def; + TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def)); + Status s; + Node* while_node = host_graph.AddNode(while_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, while_node, 0); + + // Convert `host_graph` to function. + FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +// Builds host graph for func call nodes. +Status BuildHostGraphForFuncCallNode(const string& func_call_node_name, + const string& xla_cluster_name, + const string& func_call_host_func_name, + const string& host_graph_func_name, + FunctionLibraryDefinition* fld) { + Graph host_graph(fld); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + + // Step 1: add key placeholder node. 
+ TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg + // node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, func_call_host_func_name, fld)); + + // Step 3: build a function call node with `host_func_name`, with + // `key_placeholder` as input. + NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name), + func_call_host_func_name, fld); + call_builder.Input(key_placeholder->name(), 0, DT_STRING); + call_builder.Attr("device_ordinal", device_ordinal_value); + call_builder.Attr(kXlaHasHostTransferAttrName, true); + NodeDef call_def; + TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def)); + Status s; + Node* call_node = host_graph.AddNode(call_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, call_node, 0); + + // Convert `host_graph` to function, and add a "device_ordinal" attr. + FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( + Graph* g, const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, + FunctionLibraryDefinition* fld, std::vector* host_graphs, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { + std::vector if_nodes, while_nodes, func_call_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "If") { + if_nodes.push_back(n); + } else if (n->type_string() == "While") { + while_nodes.push_back(n); + } else if 
(fld->Contains(n->type_string())) { + func_call_nodes.push_back(n); + } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) { + // Only gradient for user-defined function should be considered as + // function call node. + NameAttrList original_func; + TF_RETURN_IF_ERROR(GetNodeAttr( + n->def(), FunctionLibraryDefinition::kFuncAttr, &original_func)); + if (fld->Contains(original_func.name())) { + func_call_nodes.push_back(n); + } + } + } + + for (Node* n : func_call_nodes) { + // Extract outside compilation for the function call. + bool func_has_outside_compilation = false; + NameAttrList func; + func.set_name(n->type_string()); + typedef protobuf::Map AttrMap; + *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end()); + string new_func_name = absl::StrCat(n->name(), "_oc"); + string host_func_name = absl::StrCat("oc_func_call_host_", n->name()); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func, new_func_name, host_func_name, host_compute_core, flr, fld, + shape_inference_graphs, &func_has_outside_compilation)); + + // If the function call does not have outside compilation, nothing to do. + if (!func_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change `n` to call the new function directly. 
+ NodeDefBuilder replace_builder(n->name(), new_func_name, fld); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + replace_builder.Input(e->src()->name(), e->src_output(), + e->src()->output_type(e->src_output())); + } + for (const auto& attr : n->attrs()) { + replace_builder.Attr(attr.first, attr.second); + } + NodeDef replace_def; + TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def)); + TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def)); + replace->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + + // Build host side graph for the function call. + string oc_host_graph_name = + absl::StrCat("oc_func_host_graph_", replace->name()); + TF_RETURN_IF_ERROR( + BuildHostGraphForFuncCallNode(replace->name(), xla_cluster_name, + host_func_name, oc_host_graph_name, fld)); + + // Record the host graph. + host_graphs->push_back(oc_host_graph_name); + } + + for (Node* n : if_nodes) { + // Instantiate "then_branch" and "else_branch". + NameAttrList then_branch, else_branch; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch)); + + // Extract outside compilation for then_branch and else_branch. 
+ bool then_branch_has_outside_compilation = false; + bool else_branch_has_outside_compilation = false; + string then_branch_host_func_name = + absl::StrCat("oc_then_branch_host_if_", n->name()), + else_branch_host_func_name = + absl::StrCat("oc_else_branch_host_if_", n->name()); + string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), + else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + then_branch, then_branch_xla_func_name, then_branch_host_func_name, + host_compute_core, flr, fld, shape_inference_graphs, + &then_branch_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + else_branch, else_branch_xla_func_name, else_branch_host_func_name, + host_compute_core, flr, fld, shape_inference_graphs, + &else_branch_has_outside_compilation)); + + // If then/else branch do not have outside compilation, nothing to do. + if (!then_branch_has_outside_compilation && + !else_branch_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change If node to call the new functions. + then_branch.set_name(then_branch_xla_func_name); + n->ClearAttr("then_branch"); + n->AddAttr("then_branch", then_branch); + else_branch.set_name(else_branch_xla_func_name); + n->ClearAttr("else_branch"); + n->AddAttr("else_branch", else_branch); + + string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); + + // XLA computation: add a SendToHost node to send cond predicate. 
+ Node* pred_node; + TF_RETURN_IF_ERROR(n->input_node(0, &pred_node)); + TF_ASSIGN_OR_RETURN( + Node * send_pred_node, + BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), + host_transfer_key, pred_node, g)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{send_pred_node->name()}); + + // Add a control edge from `send_pred_node` to If node, so XlaCompiler will + // visit If node after `send_pred_node`, thus the token output for + // `send_pred_node` has been generated. + g->AddControlEdge(send_pred_node, n); + + // Build host side graph for the "If" node. + string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + then_branch_host_func_name, else_branch_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + } + + for (Node* n : while_nodes) { + // Instantiate "cond" and "body". + NameAttrList cond, body; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body)); + + // Extract outside compilation for cond and body. 
+ bool cond_has_outside_compilation = false; + bool body_has_outside_compilation = false; + string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()), + body_host_func_name = absl::StrCat("oc_body_host_while_", n->name()); + string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), + body_xla_func_name = absl::StrCat(body.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, + fld, shape_inference_graphs, &cond_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + body, body_xla_func_name, body_host_func_name, host_compute_core, flr, + fld, shape_inference_graphs, &body_has_outside_compilation)); + + // If cond/body do not have outside compilation, nothing to do. + if (!cond_has_outside_compilation && !body_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change While node to call the new functions. + cond.set_name(cond_xla_func_name); + n->ClearAttr("cond"); + n->AddAttr("cond", cond); + body.set_name(body_xla_func_name); + n->ClearAttr("body"); + n->AddAttr("body", body); + + string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); + + // XLA computation: rewrite cond function to add a SendToHost node to send + // loop predicate. + TF_RETURN_IF_ERROR( + AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + + // Build host side graph for the "While" node. 
+ string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + cond_host_func_name, body_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + } + + return Status::OK(); +} + } // namespace Status RewriteOutsideCompilationSubgraphFn::operator()( @@ -755,12 +1563,15 @@ Status RewriteOutsideCompilationSubgraphFn::operator()( // it with HostCompute node later. AddNodeAttr("_outside_compilation_subgraph", old_name, node_def); if (shapes) { - AddNodeAttr("shape_inference_graph", "", node_def); + NameAttrList shape_inference_graph; + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", *shapes, node_def); } else { string shape_inference_func_name = absl::StrCat("_outside_compilation_shape_inference_", new_name); - AddNodeAttr("shape_inference_graph", shape_inference_func_name, node_def); + NameAttrList shape_inference_graph; + shape_inference_graph.set_name(shape_inference_func_name); + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", std::vector{}, node_def); } AddNodeAttr("ancestors", std::vector{}, node_def); @@ -775,36 +1586,34 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, - const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, + const string& host_graph_func_name, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation) { - // Early return if function does not have any outside compilation nodes. 
+ // Convert the function to graph. const string& func_name = func_name_attrs.name(); - const FunctionDef* fdef = fld->Find(func_name); - if (!fdef) { - return errors::Internal("Cannot find function ", func_name); - } + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR( + flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* fbody = flr->GetFunctionBody(handle); + + // Check if we have outside compilation nodes. *has_outside_compilation = false; - for (auto& node_def : fdef->node_def()) { - if (HasNodeAttr(node_def, outside_compilation_attr_name)) { + for (Node* n : fbody->graph->nodes()) { + if (HasNodeAttr(n->def(), outside_compilation_attr_name)) { *has_outside_compilation = true; break; } } - if (!has_outside_compilation) { - return Status::OK(); - } - - // Convert the function to graph. - FunctionBody* fbody = nullptr; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(func_name), AttrSlice(&func_name_attrs.attr()), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); + // We cannot early return here, because we might have outside compilation in + // If/While function body. // Preprocess edges between different outside compilations. They will be // restored in `ConstructHostGraph()`. @@ -835,11 +1644,11 @@ Status ExtractOutsideCompilationForFunction( // If we could not infer shapes for XlaSendFromHost inputs statically, we // will set the "shape_inference_graph" attribute. In that case, copy // outside compilation subgraph as shape inference graph in `fld`. 
- string shape_inference_graph; + NameAttrList shape_inference_graph; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", &shape_inference_graph)); - if (!shape_inference_graph.empty()) { - shape_inference_graphs->push_back(shape_inference_graph); + if (!shape_inference_graph.name().empty()) { + shape_inference_graphs->push_back(shape_inference_graph.name()); const FunctionDef* xla_fdef = fld->Find(n->name()); if (!xla_fdef) { @@ -847,9 +1656,9 @@ Status ExtractOutsideCompilationForFunction( } FunctionDef shape_inference_fdef = *xla_fdef; shape_inference_fdef.mutable_signature()->set_name( - shape_inference_graph); - if (fld->Find(shape_inference_graph)) { - TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_graph.name()); + if (fld->Find(shape_inference_graph.name())) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(), shape_inference_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); @@ -858,21 +1667,22 @@ Status ExtractOutsideCompilationForFunction( } } for (Node* n : outside_compilation_nodes) { + TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( graph_out.get(), n, host_compute_core)); } - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("extract_outside_compilation_for_func_after_", func_name), - *graph_out, fld); - } + + // Handle nodes with associated functions. + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( + graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, + xla_cluster_name, host_compute_core, flr, fld, + &outside_compilation_host_graphs, shape_inference_graphs, + has_outside_compilation)); // Construct host graph. 
- if (!outside_compilation_host_graphs.empty()) { - TF_RETURN_IF_ERROR( - ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, - outside_compilation_host_graphs, fld, host_graph)); - } + TF_RETURN_IF_ERROR(ConstructHostGraph( + xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph_func_name)); // Remove the outside compilation graphs from function library. for (const string& func : outside_compilation_host_graphs) { @@ -883,20 +1693,31 @@ Status ExtractOutsideCompilationForFunction( FunctionDef updated_fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef)); + const FunctionDef* original_fdef = fld->Find(func_name); + if (original_fdef) { + for (const auto& attr : original_fdef->attr()) { + (*updated_fdef.mutable_attr())[attr.first] = attr.second; + } + } if (fld->Find(new_func_name)) { TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); } + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("extract_outside_compilation_for_func_after_", func_name), + *graph_out, fld); + } - return Status::OK(); + return ret_status; } Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryDefinition* fld) { + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) { if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile("extract_outside_compilation_before", *g, fld); } @@ -909,24 +1730,17 @@ Status ExtractOutsideCompilation( auto const& host_compute_core = iter.second.host_compute_core; bool has_outside_compilation; - std::unique_ptr host_graph; + string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name()); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - func_name_attrs, 
func_name_attrs.name(), host_compute_core, fld, - &host_graph, &shape_inference_graphs, &has_outside_compilation)); - if (host_graph) { - TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(g, host_graph.get(), n)); - } - } - - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g, - fld); + func_name_attrs, func_name_attrs.name(), host_graph_func_name, + host_compute_core, flr, fld, &shape_inference_graphs, + &has_outside_compilation)); + TF_RETURN_IF_ERROR( + ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); + TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); } - TF_RETURN_IF_ERROR(PostprocessForEncapsulation( - g, xla_cluster_attr_name, outside_compilation_attr_name, clusters)); - for (auto shape_inference_graph_name : shape_inference_graphs) { TF_RETURN_IF_ERROR( RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index 2a4f07cca213d999202024294f5d8f94527059c3..d64cc2a103ed040cbf413ac736f97f84459e869b 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, - const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, bool* has_outside_compilation); + const string& host_graph_func_name, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + bool* has_outside_compilation); // Rewrites XLA computation in `clusters` to replace outside compilation nodes // with XlaHostCompute, and 
moves those outside compilations into `g`. If shapes @@ -100,7 +101,7 @@ Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryDefinition* fld); + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index bff956100da661b679b4557fce53671e6cef88c5..7c3a24feff81b21a5d2347d21fb80988bc3e6065 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -19,8 +19,11 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" @@ -29,6 +32,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -109,10 +114,10 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) { } EXPECT_TRUE(has_control_edge_to_send_from_host); // Verify step 7: necessary attrs added to call_node_def. 
- string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, + EXPECT_EQ(shape_inference_graph.name(), "_outside_compilation_shape_inference_cluster_0"); } @@ -220,7 +225,42 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, ShapesInferred) { EXPECT_EQ(shapes[0].dim_size(), 1); } -TEST(ExtractOutsideCompilationForFunctionTest, Basic) { +class ExtractOutsideCompilationForFunctionTest : public ::testing::Test { + public: + void SetUp() override { + SessionOptions session_options; + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + device_mgr_ = absl::make_unique(std::move(devices)); + } + + Status ExtractOutsideCompilationTest( + const string &xla_cluster_attr_name, + const string &outside_compilation_attr_name, + const string &xla_cluster_name, const NameAttrList &func_name_attrs, + const string &new_func_name, const string &host_graph_func_name, + const std::map &host_compute_core, + FunctionLibraryDefinition *fld, + std::vector *shape_inference_graphs, + bool *has_outside_compilation) { + OptimizerOptions opts; + pflr_ = absl::make_unique( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts, + /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + return ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func_name_attrs, new_func_name, host_graph_func_name, host_compute_core, + flr, fld, shape_inference_graphs, has_outside_compilation); + } + + private: + std::unique_ptr device_mgr_; + std::unique_ptr pflr_; +}; + +TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { // Build the XLA computation func. 
// "const0" // "identity0" = "const0" (outside compilation cluster "0") @@ -249,27 +289,26 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Get rewritten XLA computation function. - FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - auto node_name_index = fbody->graph->BuildNodeNameIndex(); + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); // Check XlaHostCompute nodes. Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"]; @@ -292,18 +331,31 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { EXPECT_EQ(shapes[0].dim_size(), 1); // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have // empty values. 
- string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); // Check `shape_inference_graphs`. EXPECT_EQ(shape_inference_graphs.size(), 0); - // Check `host_graph`: verify we have key placeholder and sequencer. + // Check host graph: verify we have key placeholder and sequencer. + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; Node *key_placeholder = nullptr, *sequencer = nullptr; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -348,7 +400,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { } } -TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { +TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { // Build the XLA computation func. 
// "const0" FunctionDefLibrary fdl; @@ -365,25 +417,37 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); - // Check `host_graph` is empty. - EXPECT_FALSE(host_graph); + // Check host graph is empty. + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + EXPECT_EQ(host_graph->num_nodes(), 2); } -TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { +TEST_F(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { // Build the XLA computation func. 
// "const0" - // "const1" (outside compilation clsuter "0") + // "const1" (outside compilation cluster "0") FunctionDefLibrary fdl; { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -401,31 +465,43 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Check rewritten XLA graph: verify that we have no XlaHostCompute. - FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - for (Node *n : fbody->graph->nodes()) { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + for (Node *n : xla_fbody->graph->nodes()) { EXPECT_NE(n->type_string(), "XlaHostCompute"); } - // Check `host_graph`: verify we have no placeholder, but we have "const1". + // Check host graph: verify we have no placeholder, but we have "const1". 
+ FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; int num_key_placeholders = 0; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -438,4 +514,468 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { EXPECT_NE(node_name_index.find("const1"), node_name_index.end()); } +REGISTER_OP("XlaSendToHost") + .Input("input: Tinput") + .Attr("Tinput: type") + .Attr("key: string") + .SetIsStateful(); + +REGISTER_OP("XlaRecvFromHost") + .Output("output: Toutput") + .Attr("Toutput: type") + .Attr("shape: shape") + .Attr("key: string") + .SetIsStateful(); + +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { + // Build the XLA computation func. 
+ // "const0" (bool) + // "const1" (int32) + // "if0" (pred = "const0", input = "const1", then_branch = "true_fn", + // else_branch = "false_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_true_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_true_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_false_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_false_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *false_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output cond = ops::Const(s.WithOpName("const0"), true, {2}); + Output input = ops::Const(s.WithOpName("const1"), 1, {2}); + NameAttrList true_fn; + true_fn.set_name("true_fn"); + NameAttrList false_fn; + false_fn.set_name("false_fn"); + auto if_op = ops::If(s.WithOpName("if"), cond, + std::initializer_list{cond, input}, {DT_INT32}, + 
true_fn, false_fn); + ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have XlaRecvAtHost to receive "If" predicate. + Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"]; + EXPECT_NE(recv_if_pred_node, nullptr); + + // Verify we have an "If" to choose outside compilation between then_branch + // and else_branch, and it has `recv_if_pred_node` as cond input. 
+ Node *if_oc_node = node_name_index["oc_if_if"]; + EXPECT_NE(if_oc_node, nullptr); + Node *if_oc_node_cond_input; + TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input)); + EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node); + + // Check that then_branch outside compilation has node "identity_true_fn". + const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_if"); + EXPECT_NE(true_def, nullptr); + bool has_identity_true_fn_node = false; + for (const auto &node_def : true_def->node_def()) { + if (node_def.name() == "identity_true_fn") { + has_identity_true_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_true_fn_node); + + // Check that else_branch outside compilation has node "identity_false_fn". + const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_if"); + EXPECT_NE(false_def, nullptr); + bool has_identity_false_fn_node = false; + for (const auto &node_def : false_def->node_def()) { + if (node_def.name() == "identity_false_fn") { + has_identity_false_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_false_fn_node); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have XlaSendToHost to send cond predicate to host, and + // there is a control edge to If node. 
+ Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"]; + EXPECT_NE(send_if_pred_node, nullptr); + bool has_control_edge_to_if = false; + for (const Edge *e : send_if_pred_node->out_edges()) { + if (e->IsControlEdge() && e->dst()->name() == "if") { + has_control_edge_to_if = true; + break; + } + } + EXPECT_TRUE(has_control_edge_to_if); + + // Check that the "If" node now has `send_if_pred_node` as attribute + // _xla_token_input_nodes. + Node *if_node = node_name_index["if"]; + EXPECT_NE(if_node, nullptr); + std::vector token_inputs; + TF_CHECK_OK( + GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs)); + EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if")); + } +} + +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { + // Build the XLA computation func. + // "const0" (bool) + // "while0" (input = "const0", cond = "cond_fn", body = "body_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_cond_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_cond_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *cond_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + 
TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_body_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_body_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *body_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = ops::Const(s.WithOpName("const0"), true, {2}); + NameAttrList cond_fn; + cond_fn.set_name("cond_fn"); + NameAttrList body_fn; + body_fn.set_name("body_fn"); + auto while_op = + ops::While(s.WithOpName("while"), std::initializer_list{input}, + cond_fn, body_fn); + ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. 
+ { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have an "While" to execute outside compilation. + Node *while_oc_node = node_name_index["oc_while_while"]; + EXPECT_NE(while_oc_node, nullptr); + + // Check that cond outside compilation has node "identity_cond_fn". + const FunctionDef *cond_def = fld.Find("oc_cond_host_while_while"); + EXPECT_NE(cond_def, nullptr); + bool has_identity_cond_fn_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "identity_cond_fn") { + has_identity_cond_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_cond_fn_node); + + // Check that body outside compilation has node "identity_body_fn". + const FunctionDef *body_def = fld.Find("oc_body_host_while_while"); + EXPECT_NE(body_def, nullptr); + bool has_identity_body_fn_node = false; + for (const auto &node_def : body_def->node_def()) { + if (node_def.name() == "identity_body_fn") { + has_identity_body_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_body_fn_node); + } + + // Check XLA graph. + { + // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to + // host. 
+ const FunctionDef *cond_def = fld.Find("cond_fn_oc"); + EXPECT_NE(cond_def, nullptr); + bool has_send_oc_while_cond_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "send_oc_while_cond_while") { + has_send_oc_while_cond_node = true; + break; + } + } + EXPECT_TRUE(has_send_oc_while_cond_node); + } +} + +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { + // Build the XLA computation func. + // "const0" (int32) + // "fn" (input = "const0") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "fn", true_fn_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + { + std::unique_ptr g(new Graph(&fld)); + + tensorflow::TensorProto tensor_proto; + tensor_proto.set_dtype(tensorflow::DT_INT32); + tensorflow::TensorShapeProto shape; + shape.add_dim()->set_size(2); + *tensor_proto.mutable_tensor_shape() = shape; + for (int i = 0; i < 2; ++i) { + tensor_proto.add_int_val(1); + } + NodeDef const_def; + TF_CHECK_OK(NodeDefBuilder("const", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", tensor_proto) + .Finalize(&const_def)); + Status s; + Node *const_node = g->AddNode(const_def, &s); + TF_CHECK_OK(s); + + NodeDef fn_def; + TF_CHECK_OK(NodeDefBuilder("fn", "fn", &fld) + .Input("const", 0, DT_INT32) + .Finalize(&fn_def)); + Node *fn_node = g->AddNode(fn_def, &s); + 
TF_CHECK_OK(s); + g->AddEdge(const_node, 0, fn_node, 0); + + NodeDef ret_def; + TF_CHECK_OK(NodeDefBuilder("ret", "_Retval") + .Attr("index", 0) + .Attr("T", DT_INT32) + .Input("fn", 0, DT_INT32) + .Finalize(&ret_def)); + Node *ret_node = g->AddNode(ret_def, &s); + TF_CHECK_OK(s); + g->AddEdge(fn_node, 0, ret_node, 0); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + TF_CHECK_OK(fld.AddFunctionDef(*xla_fdef)); + } + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have call node for outside compilation in `fn`. 
+ Node *call_node = node_name_index["oc_call_fn"]; + EXPECT_NE(call_node, nullptr); + + FunctionBody *call_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("oc_func_call_host_fn"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &call_fbody)); + std::unique_ptr call_fbody_deleter(call_fbody); + + // Verify we have _XlaRecvAtHost and _XlaSendFromHost nodes. + bool has_recv = false, has_send = false; + for (Node *n : call_fbody->graph->nodes()) { + if (n->type_string() == "_XlaRecvAtHost") { + has_recv = true; + } else if (n->type_string() == "_XlaSendFromHost") { + has_send = true; + } + } + EXPECT_TRUE(has_recv); + EXPECT_TRUE(has_send); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have call node. + Node *fn_node = node_name_index["fn"]; + EXPECT_NE(fn_node, nullptr); + EXPECT_EQ(fn_node->type_string(), "fn_oc"); + + FunctionBody *call_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("fn_oc"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &call_fbody)); + std::unique_ptr call_fbody_deleter(call_fbody); + + // Verify we have XlaHostCompute nodes. 
+ bool has_hc = false; + for (Node *n : call_fbody->graph->nodes()) { + if (n->type_string() == "XlaHostCompute") { + has_hc = true; + } + } + EXPECT_TRUE(has_hc); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 98e344b3a080aa8aab27cd41564a90427bac151e..fba69dfccc31e01e73d8f86006b41ce5e3283f15 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -68,7 +68,12 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { Flag("tf_xla_fusion_only", &mark_for_compilation_flags->tf_xla_fusion_only, "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*.")}; + "global_jit_level is ON*."), + Flag("tf_xla_disable_deadness_safety_checks_for_debugging", + &mark_for_compilation_flags + ->tf_xla_disable_deadness_safety_checks_for_debugging, + "Disable deadness related safety checks when clustering (this is " + "unsound).")}; flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); } @@ -89,6 +94,8 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->tf_xla_clustering_fuel = std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_fusion_only = false; + mark_for_compilation_flags + ->tf_xla_disable_deadness_safety_checks_for_debugging = false; device_flags = new XlaDeviceFlags; device_flags->tf_xla_compile_on_demand = false; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 5ddea588eef5270880d91623dc05893da265960a..ed7810fcfd85c17db70d42e691446b60dc696939 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -25,27 +25,39 @@ namespace tensorflow { // Flags associated with the XLA bridge's mark_for_compilation_pass module. struct MarkForCompilationPassFlags { - int32 tf_xla_auto_jit; // Control compilation of operators into XLA - // computations on CPU and GPU devices. 
0 = use - // ConfigProto setting; -1 = off; 1 = on for things - // very likely to be improved; 2 = on for everything. - // Experimental. - int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA - // compilation. Ignored for operators placed - // on an XLA device or operators explicitly - // marked for compilation. - int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA - // compilation. - bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. - bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU - // via SessionOptions. - int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this - // many ops will be marked as eligible for - // clustering. - bool tf_xla_fusion_only; // This flag is effective only when global_jit_level - // is set to ON* and overrides its behavior. If - // true, enable fusion of element-wise operations - // only using XLA. + // Control compilation of operators into XLA computations on CPU and GPU + // devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very + // likely to be improved; 2 = on for everything. + // + // Experimental. + int32 tf_xla_auto_jit; + + // Minimum number of operators in an XLA compilation. Ignored for operators + // placed on an XLA device or operators explicitly marked for compilation. + int32 tf_xla_min_cluster_size; + + // Maximum number of operators in an XLA compilation. + int32 tf_xla_max_cluster_size; + + // Dump graphs during XLA compilation. + bool tf_xla_clustering_debug; + + // Enables global JIT compilation for CPU via SessionOptions. + bool tf_xla_cpu_global_jit; + + // "Compiler fuel" for clustering. Only this many ops will be marked as + // eligible for clustering. + int64 tf_xla_clustering_fuel; + + // tf_xla_fusion_only is effective only when global_jit_level is set to ON* + // and overrides its behavior. If true, enable fusion of element-wise + // operations only using XLA. 
+ bool tf_xla_fusion_only; + + // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then + // we do not do deadness related safety checks. This is unsound in general, + // but can be used as a debugging aid. + bool tf_xla_disable_deadness_safety_checks_for_debugging; }; // Flags associated with the XLA bridge's xla_device module. diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index ce53f70b79d97ab087fefe542920b33f883632a2..5287fd175df206970b9fa73bc6b0176eddcdcaa9 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" +#include #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" @@ -144,7 +145,9 @@ SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope, // same constant value. This helps make the generated GraphDef more readable. 
class ConstantCache { public: - explicit ConstantCache(const Scope& s) : scope_(s) {} + explicit ConstantCache(const Scope& s, + const std::vector& control_deps) + : scope_(s), control_deps_(control_deps) {} Output Get1DHostConstant(int64 constant) { auto it = cache_.find(constant); @@ -152,6 +155,9 @@ class ConstantCache { Output new_const = ops::Const(scope_.WithOpName("const_", constant), {constant}); it = cache_.insert({constant, new_const}).first; + for (const Edge* e : control_deps_) { + scope_.graph()->AddControlEdge(e->src(), new_const.node()); + } } return it->second; } @@ -159,11 +165,13 @@ class ConstantCache { private: Scope scope_; std::unordered_map cache_; + std::vector control_deps_; }; // Returns a node computing the size of the Slice op with inputs `slice_inputs`. Status ComputeSliceSize(const Scope& host_scope, - const SliceInputs& slice_inputs, Output* size) { + const SliceInputs& slice_inputs, + std::vector control_deps, Output* size) { // If slice_size[i] >= 0 then slice_size[i] = slice_size[i]. // // If slice_size[i] == -1 then slice_size[i] = input_size[i] - @@ -183,7 +191,7 @@ Status ComputeSliceSize(const Scope& host_scope, ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input, ops::Shape::OutType(DT_INT64)); - ConstantCache constant_pool(host_scope); + ConstantCache constant_pool(host_scope, control_deps); std::vector slice_size; for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) { @@ -209,11 +217,16 @@ Status ComputeSliceSize(const Scope& host_scope, } // Trivial ConcatV2 nodes (with exactly one input) are disallowed. - *size = - slice_size.size() == 1 - ? 
slice_size[0] - : ops::Concat(host_scope.WithOpName("slice_size"), slice_size, - ops::Const(host_scope.WithOpName("concat_axis"), 0)); + if (slice_size.size() == 1) { + *size = slice_size[0]; + } else { + auto concat_axis = ops::Const(host_scope.WithOpName("concat_axis"), 0); + for (const Edge* e : control_deps) { + host_scope.graph()->AddControlEdge(e->src(), concat_axis.node()); + } + *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size, + concat_axis); + } return Status::OK(); } @@ -234,12 +247,21 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice")); Scope host_scope = main_scope.WithAssignedDevice(host_name); + // In the future we may want to be clever here and avoid the extra Cast ops. SliceInputs slice_inputs_int64 = MakeSliceIndexAndSizeInt64(host_scope, slice_inputs); + // Create a list of all control dependencies to be copied when possibly + // replacing nodes related to slice_size. + Node* old_size; + std::vector old_size_ctrl_deps; + TF_RETURN_IF_ERROR(slice->input_node(2, &old_size)); + absl::c_copy_if(old_size->in_edges(), std::back_inserter(old_size_ctrl_deps), + [](const Edge* e) { return e->IsControlEdge(); }); + Output slice_size; - TF_RETURN_IF_ERROR( - ComputeSliceSize(host_scope, slice_inputs_int64, &slice_size)); + TF_RETURN_IF_ERROR(ComputeSliceSize(host_scope, slice_inputs_int64, + old_size_ctrl_deps, &slice_size)); *result = ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name()) @@ -291,9 +313,9 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, return Status::OK(); } -// Return true if `n` is a slice we can rewrite to have a static shape +// Return true if `n` is a slice we should rewrite to have a static shape // (i.e. have the output shape only depend on the "size" input). 
-xla::StatusOr IsRewritableSlice(Node* n) { +xla::StatusOr ShouldRewriteSlice(Node* n) { if (n->type_string() != "Slice") { return false; } @@ -311,14 +333,20 @@ xla::StatusOr IsRewritableSlice(Node* n) { // If slice_size[i] < -1 for any i then executing the slice will throw an // error, and we don't do anything here. - return absl::c_all_of(slice_inputs->size_as_vector, - [](int64 size_i) { return size_i >= -1; }); + bool slice_size_has_error = absl::c_all_of( + slice_inputs->size_as_vector, [](int64 size_i) { return size_i >= -1; }); + if (!slice_size_has_error) { + return false; + } + + // No point in rewriting slices that have both size and begin as constants. + return !slice_inputs->begin.node()->IsConstant(); } Status FindAndRewriteSlices(Graph* g, bool* changed) { std::vector slices_to_rewrite; for (Node* n : g->nodes()) { - TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n)); + TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n)); if (is_rewritable) { slices_to_rewrite.push_back(n); } diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index a2f1b831ad7605237e23c15cc43b337e06265553..2add2c13f92f561904163012ee16cc17ce5badce 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -401,5 +401,57 @@ TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) { Name("begin/static_shaped_slice/static_shaped_slice"))), _))); } + +// New constants being created need to have control dependencies copied to +// ensure correct control flow analysis in TF V2. 
+TEST(SliceToDynamicSliceRewriteTest, WithControlDepsToConstant) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + // Add an additional dependency that should still exist with the new size + // variables. + Output dependency = ops::Placeholder(root.WithOpName("dependency"), DT_BOOL); + root.graph()->AddControlEdge(dependency.node(), size.node()); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + // Check that the new constants have control dependencies. + Node* const_0 = testing::FindNodeByName(result.get(), + "slice/static_shaped_slice/const_0"); + EXPECT_NE(const_0, nullptr); + EXPECT_THAT(const_0, + NodeWith(Op("Const"), CtrlDeps(NodeWith(Op("Placeholder"), + Name("dependency"))))); +} + +TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithConstBegin) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Const(root.WithOpName("begin"), {10, 10}); + Output size = ops::Const(root.WithOpName("size"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* slice_node = testing::FindNodeByName(result.get(), "slice"); + EXPECT_THAT(slice_node, + NodeWith(Op("Slice"), Inputs(Out(NodeWith(Op("Placeholder"))), + Out(NodeWith(Op("Const"))), + Out(NodeWith(Op("Const")))))); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD
b/tensorflow/compiler/jit/kernels/BUILD index 0583774714c6db7a2fa515fc8a0d304e1898db97..d0fa2c40be9d6b13ec736a9d6483dae0b4f0f45e 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -19,12 +19,14 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index ad71df5a694a5f8da94675049df1062a7edb6253..997ef6e14bb9bd16ddac13eaf67368966818b29e 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" @@ -35,6 +36,8 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" @@ -304,10 +307,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map variables; - OP_REQUIRES_OK( - ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, - constants_, /*lazy=*/false, &client, - &variables, &kernel, &executable)); + { + Status s = CompileToLocalExecutable( + ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false, + &client, &variables, &kernel, &executable); + if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU || + platform_info_.device_type().type_string() == DEVICE_GPU)) { + // Suggest auto jit if the failure was with GPU or CPU. + errors::AppendToMessage(&s, + xla::status_macros::kPossibleAutoJitAlternative); + } + + OP_REQUIRES_OK(ctx, s); + } se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 6618e3a58ab7b6374ed775cd6e4e18a6a4975588..d9a83049d6352f04f9237f21b44bdb5ea18e518a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -41,7 +42,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -677,12 +678,28 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; + // Deadness analysis expects a graph with source and sink edges properly + // connected but sometimes the incoming graph does not follow this invariant. + // So fix up the source and sink edges before calling into deadness analysis. 
+ FixupSourceAndSinkEdges(options.graph->get()); + std::unique_ptr deadness; { XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); } + bool deadness_analysis_disabled = + GetMarkForCompilationPassFlags() + ->tf_xla_disable_deadness_safety_checks_for_debugging; + + if (deadness_analysis_disabled) { + LOG(WARNING) << "Deadness analysis was manually disabled via " + "--tf_xla_disable_deadness_safety_checks_for_debugging; " + "auto-clustering " + "is unsound!"; + } + auto is_compilable = [&](const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), @@ -715,9 +732,12 @@ Status MarkForCompilationPass::Run( // and some are dead) then don't compile it. XLA cannot represent the // deadness semantics of these nodes correctly and auto-clustering these // nodes can cause deadness to propagate to nodes that should be live. - if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { - VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; - return false; + if (!deadness_analysis_disabled) { + if (node->IsMerge() || + deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; + return false; + } } // Check for fusable ops only if requested. @@ -1145,6 +1165,29 @@ Status MarkForCompilationPass::RunImpl( if (flags->tf_xla_clustering_debug) { dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def); + + // We also dump out an annotated version of the TF graph where the node + // names are prefixed with the cluster names. This can help visualizing the + // clustering decisions on TensorBoard.
+ Graph new_graph((*options.graph)->op_registry()); + CopyGraph(**options.graph, &new_graph); + + for (Node* n : new_graph.nodes()) { + if (absl::optional cluster_name = + GetXlaClusterForNode(*n)) { + n->set_name(absl::StrCat(*cluster_name, "/", n->name())); + } else if (n->type_string() == "VarHandleOp") { + n->set_name(absl::StrCat("varhandle/", n->name())); + } else { + // There is room for improvement here. In particular, it may help to + // split these unclustered nodes into classes where every node in a + // specific class has edges to and from the same set of clusters. + n->set_name(absl::StrCat("unclustered/", n->name())); + } + } + + dump_graph::DumpGraphToFile("mark_for_compilation_annotated", new_graph, + options.flib_def); } VLogClusteringSummary(*graph); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index bf2c5508ea9e987e80093f4c2e15d3ff5191126f..c2b6250f738fafa35b2c5f79e97cf1281b50a316 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -151,7 +151,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, Complex128Unsupported) { +TEST(XlaCompilationTest, StringUnsupported) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -159,10 +159,10 @@ TEST(XlaCompilationTest, Complex128Unsupported) { Node* a = ops::SourceOp( "Const", builder.opts() .WithName("A") - .WithAttr("dtype", DT_COMPLEX128) - .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); - Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); - ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + .WithAttr("dtype", DT_STRING) + .WithAttr("value", Tensor(DT_STRING, TensorShape()))); + Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B")); + ops::BinaryOp("StringSplit", a, b, 
builder.opts().WithName("C")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 42ea3926e16ae791dbe1bede3b8742383db7667c..e1fd2aaee2822daeffb415d053c9c4f56002a856 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -120,6 +120,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { NodeDef ndef = n->def(); ndef.set_name(absl::StrCat(n->name(), "/declustered")); + MergeDebugInfo(NodeDebugInfo(n->def()), &ndef); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 38a54cc5efae35ad77b6dc8039c653e920cfc071..1d81a8f4fcbf050663626b1f7660afd71f4027bc 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 80c691fe490c1092315708a2da754d367d585300..a27e0d9f2a6ecddfdbdb29be673084d77a178d8a 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph, // shapes, even if no shape function is registered for a node. 
Status status = shape_refiner->AddNode(n); if (!status.ok()) { - VLOG(1) << "Shape inference failed for node: " << status; + VLOG(1) << "Shape inference failed for node " << n->name() << ": " + << status; + } else { + shape_inference::InferenceContext* context = shape_refiner->GetContext(n); + for (int i = 0; i < n->num_outputs(); i++) { + shape_inference::ShapeHandle handle = context->output(i); + VLOG(4) << "Output " << i << " for node " << n->name() << ": " + << context->DebugString(handle); + } } if (n->type_string() == "_Arg") { diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index fef28fc810cb4e544fe3f271f0b96cebd8a96779..3adcfef4dacecb343812cefc3a893a65c74ca101 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -19,9 +19,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -43,7 +43,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, return ""; } - auto node_name = [cycles, &graph](int node_id) { + auto node_name = [&graph](int node_id) { if (!FastBoundsCheck(node_id, graph.num_node_ids())) { return string("(null)"); } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3df5479a55e841380ca7b8cdd0add9fd17487091..611515cf33bc1abe21e06eb7f1513800276e095b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -38,6 +39,8 @@ limitations under the License. namespace tensorflow { +constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold; + XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} @@ -60,7 +63,7 @@ XlaCompilationCache::~XlaCompilationCache() { // about? } -string XlaCompilationCache::DebugString() { +string XlaCompilationCache::DebugString() const { return "XLA JIT compilation cache"; } @@ -68,9 +71,9 @@ string XlaCompilationCache::DebugString() { // arguments in the supplied list. string XlaCompilationCache::Signature::HumanString() const { string result = name; - for (const auto& a : arg_types) { - absl::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + for (const auto& a : arg_shapes) { + absl::StrAppend(&result, ",", DataTypeString(a.first)); + absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]"); } for (const auto& v : arg_values) { @@ -81,7 +84,7 @@ string XlaCompilationCache::Signature::HumanString() const { bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (name != other.name) return false; - if (arg_types != other.arg_types) return false; + if (arg_shapes != other.arg_shapes) return false; if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { @@ -97,10 +100,10 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { uint64 XlaCompilationCache::Signature::Hash::operator()( const XlaCompilationCache::Signature& signature) const { uint64 h = std::hash()(signature.name); - for (const auto& arg : signature.arg_types) { + for (const auto& arg : signature.arg_shapes) { h = 
Hash64Combine(h, std::hash()(static_cast(arg.first))); - h = Hash64Combine(h, std::hash()(arg.second.dims())); - for (int dim : arg.second.dim_sizes()) { + h = Hash64Combine(h, std::hash()(arg.second.size())); + for (int dim : arg.second) { h = Hash64Combine(h, std::hash()(dim)); } } @@ -124,7 +127,7 @@ XlaCompilationCache::BuildSignature( break; case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kResource: - signature.arg_types.emplace_back(arg.type, arg.shape); + signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes()); break; default: return errors::InvalidArgument( diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 846d0c963dbfdf55f51120f2f138d12f5f63839b..7748b4700f39da4f952278ca6c6d2cadff4d3fb8 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -88,14 +88,16 @@ class XlaCompilationCache : public ResourceBase { xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } - string DebugString() override; + string DebugString() const override; // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. struct Signature { string name; - std::vector> arg_types; + // List of Tensor types & shapes for the parameter and resource arguments + // to the compilation, ordered by argument number. + std::vector>> arg_shapes; // List of Tensor values for compile-time constant arguments to the // compilation, ordered by argument number. Tensors must be in host memory. 
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index e9770647e7ba96cc1db026d12d5f11f52ce98d35..94dc61d55fb047c0ea81d98fde24cb55387c27d7 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -83,9 +83,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { +constexpr std::array kAllXlaCpuTypes = { {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, - DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 4201ff91a89b1bee370e6a43337c51abe3bf974a..56c4220f12b54be09821eca4590df52e8e71850b 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -102,7 +102,8 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - absl::make_unique(); + absl::make_unique( + backend->stream_executors()[device_ordinal]); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -201,7 +202,8 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, jit_device_name_(options.compilation_device_name), platform_(options.platform), use_multiple_streams_(options.use_multiple_streams), - shape_representation_fn_(options.shape_representation_fn) { + shape_representation_fn_(options.shape_representation_fn), + allowed_devices_(options.allowed_devices) { VLOG(1) << "Created XLA device " << options.compilation_device_name << " " << this; thread_pool_.reset(new thread::ThreadPool(session_options.env, 
"xla_device", @@ -218,9 +220,6 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, XlaDevice::~XlaDevice() { VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } if (device_context_) { device_context_->Unref(); } @@ -234,7 +233,8 @@ xla::LocalClient* XlaDevice::client() const { // TODO(b/78468222): This can fail, at least when the backend is GPU and // there is no GPU on the host. - return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie(); + return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_) + .ValueOrDie(); } Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { @@ -396,12 +396,6 @@ Status XlaDevice::Sync() { if (!stream) return Status::OK(); Status status = stream->BlockHostUntilDone(); - { - mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } - } TF_RETURN_IF_ERROR(status); if (!stream->ok()) { return errors::Internal("XlaDevice::Sync() failed."); @@ -410,6 +404,8 @@ Status XlaDevice::Sync() { return Status::OK(); } +// TODO(b/112409994): This is no longer necessary. Consolidate it with the +// synchronous version. void XlaDevice::Sync(const DoneCallback& done) { VLOG(1) << "XlaDevice::Sync (asynchronous)"; std::shared_ptr stream; @@ -422,14 +418,20 @@ void XlaDevice::Sync(const DoneCallback& done) { return; } + // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at + // the end of the stream, after everything that has already been enqueued + // there at this moment. When the host callback is called, everything before + // it must have already finished, and the host callback will then place the + // task below onto a background thread. (See the implementation of + // ThenEnqueueOnBackgroundThread for details.) 
Therefore, when the done + // callback is finally called from that background thread, we know for sure + // that everything enqueued onto the stream (i.e., the device) at this very + // moment--when ThenEnqueueOnBackgroundThread is called--will have finished. + // This achieves a device-wide sync. stream->ThenEnqueueOnBackgroundThread( - [this, stream, done](se::StreamExecutor*) { + [stream, done](se::StreamExecutor*) { tracing::ScopedActivity activity("XlaDevice::Sync::Callback", /*is_expensive=*/true); - mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } done(stream->ok() ? Status::OK() : errors::Internal("XlaDevice::Sync() failed.")); }); @@ -468,57 +470,50 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, return status; } -void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) { +void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) { mutex_lock lock(mu_); sync_on_completion_ = sync_on_completion; } -bool XlaDevice::RequiresSyncOnCompletion() const { +bool XlaDevice::AllowsSyncOnCompletion() const { mutex_lock lock(mu_); return sync_on_completion_; } -XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - XlaDevice* device) - : device_(device) { - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; +void XlaDevice::SetHandleDeviceErrorCallback(std::function callback) { + mutex_lock lock(mu_); + device_error_callback_ = callback; } -XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() { - if (device_) { - mutex_lock lock(device_->mu_); - --device_->outstanding_asynchronous_operations_; - device_->outstanding_asynchronous_operations_cv_.notify_all(); +Status XlaDevice::HandleDeviceError() { + std::function local_device_error_callback; + { + mutex_lock lock(mu_); + local_device_error_callback = device_error_callback_; } + if (local_device_error_callback != nullptr) { + 
return local_device_error_callback(); + } + return Status::OK(); } -XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - const XlaDevice::AsynchronousOperationHandle& other) - : device_(other.device_) { - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; -} - -XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - XlaDevice::AsynchronousOperationHandle&& other) - : device_(other.device_) { - other.device_ = nullptr; -} - -XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: -operator=(const XlaDevice::AsynchronousOperationHandle& other) { - device_ = other.device_; - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; - return *this; -} - -XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: -operator=(XlaDevice::AsynchronousOperationHandle&& other) { - device_ = other.device_; - other.device_ = nullptr; - return *this; +Status XlaDevice::RefreshStatus() { + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) { + return Status::OK(); + } + Status status = stream->RefreshStatus(); + if (!status.ok()) { + // Ignore errors from HandleDeviceError, since by definition the status is + // already non-ok, so there's nothing extra to report if HandleDeviceError + // itself returns an error. + HandleDeviceError().IgnoreError(); + } + return status; } XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index c8bb276cdb9673fdcba4cc15a9f33ecd3ae96dbb..977f5f5cf151d979d025c2966012445af04fc502 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -24,7 +24,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include +#include "absl/types/optional.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -123,6 +125,11 @@ class XlaDevice : public LocalDevice { // If padded_shape_fn is empty, a default implementation that returns // the logical on-device shape without padding is used. PaddedShapeFn padded_shape_fn; + + // Set of devices to use. This controls which of the devices on the given + // platform will have resources allocated. For GPUs this will be + // filled from visible_gpu_devices list from session configuration. + absl::optional> allowed_devices; }; // Creates a new XLA Device. @@ -160,35 +167,16 @@ class XlaDevice : public LocalDevice { Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); // Instructs this XlaDevice to return 'sync_on_completion' for - // RequiresSyncOnCompletion(). - void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); - - bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + // AllowsSyncOnCompletion(). + void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); + bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); - // A simple RAII handle. On construction the device's - // outstanding_asynchronous_operations_ field is incremented; on destruction - // it is decremented. 
- class AsynchronousOperationHandle { - public: - AsynchronousOperationHandle(XlaDevice* device); - ~AsynchronousOperationHandle(); - AsynchronousOperationHandle(const AsynchronousOperationHandle& other); - AsynchronousOperationHandle(AsynchronousOperationHandle&& other); - AsynchronousOperationHandle& operator=( - const AsynchronousOperationHandle& other); - AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other); + // Installs an error handling callback when RefreshStatus sees !status.ok(). + void SetHandleDeviceErrorCallback(std::function callback); - private: - XlaDevice* device_ = nullptr; - }; - - AsynchronousOperationHandle CreateAsynchronousOperationHandle() { - return AsynchronousOperationHandle(this); - } + Status RefreshStatus() override LOCKS_EXCLUDED(mu_); private: - friend class AsynchronousOperationHandle; - xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -202,6 +190,9 @@ class XlaDevice : public LocalDevice { static Status GetMetadataFromDevice(DeviceBase* device, const XlaDevice::Metadata** metadata); + // Handles error when RefreshStatus sees !status.ok(). + Status HandleDeviceError(); + mutable mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; @@ -248,14 +239,17 @@ class XlaDevice : public LocalDevice { // Thread pool used for running closures std::unique_ptr thread_pool_; - // True if the device requires XlaDevice::Sync to be called on completion + // True if the device allows XlaDevice::Sync to be called on completion // regardless of status. - bool sync_on_completion_ GUARDED_BY(mu_) = false; + bool sync_on_completion_ GUARDED_BY(mu_) = true; + + // A callback that will be invoked when RefreshStatus sees a status error. + std::function device_error_callback_ GUARDED_BY(mu_); - // Count of outstanding asynchronous operations which must be zero on Sync() - // completion. 
- int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0; - condition_variable outstanding_asynchronous_operations_cv_; + // Set of devices to use. This controls which of the devices on the given + // platform will have resources allocated. For GPUs this will be + // filled from visible_gpu_devices list from session configuration. + absl::optional> allowed_devices_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6e6532731e64bd42ee56aa719748988f321e0f17..05b9c511866d3ca48ec3519bee8a4dbf6086f6ac 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -29,7 +29,10 @@ limitations under the License. namespace tensorflow { // The allocator used for Tensors assigned to the XLA device. -XlaDeviceAllocator::XlaDeviceAllocator() {} +XlaDeviceAllocator::XlaDeviceAllocator( + stream_executor::StreamExecutor* stream_executor) + : stream_executor_(stream_executor) {} + XlaDeviceAllocator::~XlaDeviceAllocator() = default; string XlaDeviceAllocator::Name() { return "xla"; } @@ -48,7 +51,21 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { delete XlaTensor::FromOpaquePointer(ptr); } -void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } +absl::optional XlaDeviceAllocator::GetStats() { + absl::optional se_stats = + stream_executor_->GetAllocatorStats(); + if (!se_stats) { + return absl::nullopt; + } + + tensorflow::AllocatorStats tf_stats; + tf_stats.num_allocs = se_stats->num_allocs; + tf_stats.bytes_in_use = se_stats->bytes_in_use; + tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use; + tf_stats.largest_alloc_size = se_stats->largest_alloc_size; + tf_stats.bytes_limit = se_stats->bytes_limit; + return tf_stats; +} XlaDeviceContext::XlaDeviceContext( std::shared_ptr compute_stream, @@ -79,6 +96,13 @@ XlaDeviceContext::XlaDeviceContext( } } +void 
XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("XLA->XLA same-device copies not implemented.")); +} + void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, @@ -124,7 +148,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, xla::ShapeUtil::MakeShape(shape.element_type(), xla::AsInt64Slice(shape.dimensions()))); - VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " " << xla_tensor->shaped_buffer().ToString(); if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( @@ -207,7 +231,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal, [ref, xla_tensor, done](xla::Status status) { done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " + VLOG(2) << "Transfer from device as literal: " << xla_tensor->shaped_buffer().ToString(); return status; }()); diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1e18df197a2dd65590c5181b4dae4481dca36641..1ce64ad323b4827adc2f4d48841315fbde43e532 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -34,14 +34,18 @@ namespace tensorflow { // empty, XlaTensor. class XlaDeviceAllocator : public Allocator { public: - XlaDeviceAllocator(); + XlaDeviceAllocator(se::StreamExecutor* stream_executor); ~XlaDeviceAllocator() override; string Name() override; void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void GetStats(AllocatorStats* stats) override; + absl::optional GetStats() override; + + private: + // The stream executor of the device. 
+ se::StreamExecutor* stream_executor_; }; // Helper class for managing data transfers between host and XLA devices. @@ -62,6 +66,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 927f983ba9ef23c8509523f42366c0c89c29db9f..09e04d22def9c39f45c2737c1d4a5e7787e3fdc0 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/data/generator_dataset_op.h" #include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/function_ops.h" @@ -241,6 +242,8 @@ class XlaAssignVariableOp : public OpKernel { data::AnonymousIteratorHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ data::IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ + data::IteratorGetNextAsOptionalOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ data::IteratorGetNextSyncOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ @@ -251,6 +254,15 @@ class XlaAssignVariableOp : public OpKernel { .Device(DEVICE) \ .HostMemory("string_handle"), \ data::IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE), \ + 
data::OptionalNoneOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE), \ + data::OptionalFromValueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"), \ + data::OptionalHasValueOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE), \ + data::OptionalGetValueOp); \ REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 0191315a66f4d331e54fadc9dc6a073a05fd67ef..b29f6a009b9e9fdba76ac55386a4bec2f339cc0e 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -29,6 +29,30 @@ limitations under the License. namespace tensorflow { +// Returns a set containing the device ids contained in visible_device_list or +// nullopt if it is empty. It returns error in case of malformed configuration +// string. +static xla::StatusOr>> ParseVisibleDeviceList( + const string& visible_device_list) { + std::set gpu_ids; + if (visible_device_list.empty()) { + return {{absl::nullopt}}; + } + const std::vector visible_devices = + absl::StrSplit(visible_device_list, ','); + for (const string& platform_gpu_id_str : visible_devices) { + int32 platform_gpu_id; + if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { + return errors::InvalidArgument( + "Could not parse entry in 'visible_device_list': '", + platform_gpu_id_str, + "'. 
visible_device_list = ", visible_device_list); + } + gpu_ids.insert(platform_gpu_id); + } + return {{gpu_ids}}; +} + class XlaGpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, @@ -57,33 +81,16 @@ Status XlaGpuDeviceFactory::CreateDevices( } string allowed_gpus = session_options.config.gpu_options().visible_device_list(); - std::set gpu_ids; - int num_visible_devices = platform.ValueOrDie()->VisibleDeviceCount(); - if (allowed_gpus.empty()) { - for (int i = 0; i < num_visible_devices; ++i) { - gpu_ids.insert(i); - } - } else { - // For loop below is copied from gpu/gpu_device.cc. It validates - // the visible_device_list and populates gpu_ids set. - const std::vector visible_devices = - absl::StrSplit(allowed_gpus, ','); - for (const string& platform_gpu_id_str : visible_devices) { - int32 platform_gpu_id; - if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { - return errors::InvalidArgument( - "Could not parse entry in 'visible_device_list': '", - platform_gpu_id_str, "'. visible_device_list = ", allowed_gpus); - } - if (platform_gpu_id < 0 || platform_gpu_id >= num_visible_devices) { - return errors::InvalidArgument( - "'visible_device_list' listed an invalid GPU id '", platform_gpu_id, - "' but visible device count is ", num_visible_devices); - } - gpu_ids.insert(platform_gpu_id); + absl::optional> gpu_ids = + ParseVisibleDeviceList(allowed_gpus).ValueOrDie(); + if (!gpu_ids) { + gpu_ids.emplace(); + // Fill the gpu_ids set with all devices if config string is empty. 
+ for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + gpu_ids->insert(i); } } - for (int i : gpu_ids) { + for (int i : *gpu_ids) { XlaDevice::Options options; options.platform = platform.ValueOrDie(); options.device_name_prefix = name_prefix; @@ -91,6 +98,7 @@ Status XlaGpuDeviceFactory::CreateDevices( options.device_ordinal = i; options.compilation_device_name = DEVICE_GPU_XLA_JIT; options.use_multiple_streams = true; + options.allowed_devices = gpu_ids; auto device = absl::make_unique(session_options, options); Status status = device->UseGpuDeviceInfo(); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4007309ed1c57b663dca5bac0df11260bf1327f3..e1a582406153d2af447fa9d4ebcaf0bf0842b132 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -26,9 +26,9 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { +constexpr std::array kExecAllTypes = { {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_BOOL, DT_BFLOAT16}}; + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 3b0bda4caa161a7561a3098b89420329998ff8a7..c64981053fad2dbf1e8bcd623a940ded8b4d9150 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -237,7 +237,7 @@ void XlaComputationLaunchContext::PopulateInputs( const xla::Shape on_device_shape = client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); - if (xla::ShapeUtil::IsTuple(on_device_shape)) { + if (on_device_shape.IsTuple()) { const XlaTensor* xla_tensor = XlaTensor::FromTensor(t); CHECK(xla_tensor && 
xla_tensor->has_shaped_buffer()); arg_ptrs_[i] = const_cast(&xla_tensor->shaped_buffer()); @@ -274,7 +274,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( // If the on-host-shape isn't a tuple, create a new single-element tuple // buffer with a nullptr root index table. This allows the code below to treat // output as a tuple unconditionally. - if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) { + if (!output.on_host_shape().IsTuple()) { ShapedBuffer nontuple_buffer = output.release(); ShapedBuffer buffer( xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), @@ -377,7 +377,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( } if (VLOG_IS_ON(3)) { - VLOG(3) << ctx->mutable_output(i)->DebugString(); + VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString(); } } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 093b61629cd0b04d5d8488139b8d7262b739f86d..7c1e0daf0b7b418530367cb80fbd18b93e8e5f5e 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -72,7 +72,7 @@ py_test( tf_xla_py_test( name = "adadelta_test", - size = "large", + size = "medium", srcs = ["adadelta_test.py"], deps = [ ":xla_test", @@ -230,6 +230,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:standard_ops", ], ) @@ -242,9 +243,33 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:map_fn", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "self_adjoint_eig_op_test", + size = "medium", + srcs = ["self_adjoint_eig_op_test.py"], + # TODO(kuny): remove it after b/124377352 is fixed. 
+ disabled_backends = [ + "cpu", + "gpu", + "cpu_ondemand", + ], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:map_fn", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", ], ) @@ -277,10 +302,9 @@ tf_xla_py_test( ], ) -# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "large", + size = "medium", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", @@ -406,7 +430,7 @@ tf_xla_py_test( tf_xla_py_test( name = "eager_test", - size = "large", + size = "medium", srcs = ["eager_test.py"], deps = [ ":xla_test", @@ -677,6 +701,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:standard_ops", ], ) @@ -826,6 +851,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python:standard_ops", "//tensorflow/python:stateless_random_ops", ], ) @@ -1188,11 +1214,18 @@ tf_xla_py_test( tf_xla_py_test( name = "quantized_ops_test", - size = "small", + size = "medium", srcs = ["quantized_ops_test.py"], + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], deps = [ ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", + "//tensorflow/python:bitwise_ops", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py index b7b7fda293b69d6f0cec61d0d234277636a3670d..6cf16cc07ff503c4f3e008cfb720224abe5e9166 100644 --- a/tensorflow/compiler/tests/adadelta_test.py +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -32,10 +32,18 @@ class 
AdadeltaOptimizerTest(xla_test.XLATestCase): def testBasic(self): num_updates = 4 # number of ADADELTA steps to perform + if "CPU" in self.device: + # To avoid timeout on CPU. + all_grad = [0.2, 0.01] + all_lr = [1.0, 0.1] + else: + all_grad = [0.2, 0.1, 0.01] + all_lr = [1.0, 0.5, 0.1] + for dtype in self.float_types: with self.cached_session(), self.test_scope(): - for grad in [0.2, 0.1, 0.01]: - for lr in [1.0, 0.5, 0.1]: + for grad in all_grad: + for lr in all_lr: var0_init = [1.0, 2.0] var1_init = [3.0, 4.0] var0 = resource_variable_ops.ResourceVariable( diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9a5423c1b2a5df7880453cbb328f6a8174066255..c829c50b5518b29c96c0b0117a6cd143911bd1fc 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -311,6 +311,30 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + x = np.array([ + -0.0, 0.0, -0.0, +0.0, np.inf, np.inf, -np.inf, -np.inf, 2.0, 2.0, + 1.0 + ], + dtype=dtype) + y = np.array( + [-0.0, 0.0, +0.0, -0.0, 1.0, -1.0, 1.0, -1.0, 2.0, 1.0, 2.0], + dtype=dtype) + expected = np.nextafter(x, y) + + # We use assertAllEqual to expose any bugs hidden by relative or + # absolute error tolerances. 
+ def NextAfterEqualityTest(result, expected, rtol): + del rtol + return self.assertAllEqual(result, expected) + + self._testBinary( + math_ops.nextafter, + x, + y, + expected=expected, + equality_test=NextAfterEqualityTest) + # min/max not supported for complex if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( @@ -400,7 +424,7 @@ class BinaryOpsTest(xla_test.XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - ctypes = {np.complex64: np.float32} + ctypes = {np.complex64: np.float32, np.complex128: np.float64} self._testBinary( math_ops.complex, np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]), diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 447a7de2cb6526a5dcf7789d4f2bffb5e733e8c0..ed580f95b6c2f57dfdf46cfcd64cabb452980c5d 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -5,6 +5,7 @@ load("//tensorflow/compiler/tests:plugin.bzl", "plugins") load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", + "tf_exec_compatible_with", ) def all_backends(): @@ -64,7 +65,7 @@ def tf_xla_py_test( if backend == "cpu": backend_args += [ "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128", ] elif backend == "gpu": backend_args += [ @@ -84,6 +85,7 @@ def tf_xla_py_test( else: fail("Unknown backend {}".format(backend)) + test_tags = tags + backend_tags native.py_test( name = test_name, srcs = srcs, @@ -92,7 +94,8 @@ def tf_xla_py_test( main = "{}.py".format(name) if main == None else main, data = data + backend_data, deps = deps + backend_deps, - tags = tags + backend_tags, + tags = test_tags, + exec_compatible_with = tf_exec_compatible_with({"tags": 
test_tags}), **kwargs ) test_names.append(test_name) diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 5d5e486f616937601214aa169a4c329ab78932c8..eec69ea7d2d9af9ff570f927fb25b668ccce2b97 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -119,7 +119,7 @@ class CategoricalTest(xla_test.XLATestCase): def testSamplingCorrectness(self): np.random.seed(1618) # Make it reproducible. - num_samples = 21000 + num_samples = 40000 rand_probs = np.random.dirichlet([1., 1., 2., 3.]) rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3) # batched diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 2187f57960f80300d631bdc7eb8fe5e9c8dddeea..76750decd2963ea12680a46d7340f48e8b011fa9 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -294,6 +294,9 @@ class ConcatTest(xla_test.XLATestCase): # The purpose of this is to ensure that XLA on GPU will not run out of memory # with too many arguments. 
def testConcatLargeNumberOfTensors(self): + if "CPU" in self.device: + self.skipTest("This test can time out on CPU, so we will just allow " + "other backends to catch this specific error.") with self.cached_session(): with self.test_scope(): for concat_dim in range(2): diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index bf5ea7b1fb6fb3c774c4db20d059f131990d20d3..b7d08df9f7d144b71fd0b09535e10b8f596ea6ca 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase): x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 174bfa9efbcd7dcb4f895237eb01c17bc4a3a6b4..90146e6b27ca31304a2549ec247412341efe390c 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -350,8 +350,13 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding) - def _CompareBackpropFilter(self, input_sizes, 
filter_sizes, output_sizes, - stride, padding): + def _CompareBackpropFilter(self, + input_sizes, + filter_sizes, + output_sizes, + stride, + padding, + data_format="NHWC"): x0 = np.random.rand(*input_sizes).astype(np.float32) x2 = np.random.rand(*output_sizes).astype(np.float32) @@ -360,13 +365,30 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) + native_t0 = t0 + native_t2 = t2 + strides = [1, stride, stride, 1] + if use_xla: + if data_format == "NCHW": + # Transpose from NWHC input to NCHW + # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] + native_t0 = array_ops.transpose(t0, [0, 3, 1, 2]) + native_t2 = array_ops.transpose(t2, [0, 3, 1, 2]) + strides = [1, 1, stride, stride] with self.test_scope(): backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, + t1, + native_t2, + strides=strides, + padding=padding, + data_format=data_format) else: + # For CPU, the format NCHW is not supported. Therefore we always use + # NHWC here. 
backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, t1, native_t2, strides=strides, padding=padding) ret = backprop.eval({t0: x0, t2: x2}) self.assertShapeEqual(ret, backprop) return ret @@ -379,11 +401,24 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): for index, (input_size, filter_size, output_size, stride, padding) in enumerate(ConfigsToTest()): print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:", - input_size, "*", filter_size, "stride:", stride, "padding:", - padding) + input_size, "*", filter_size, "producing output", output_size, + "stride:", stride, "padding:", padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding) + def testDepthwiseConv2DFilterGradFormatNCHWCompare(self): + for index, (input_size, filter_size, output_size, stride, + padding) in enumerate(ConfigsToTest()): + print("Testing DepthwiseConv2DFilterGradFormatNCHWCompare,", index, + "th config:", input_size, "*", filter_size, "producing output", + output_size, "stride:", stride, "padding:", padding) + self._CompareBackpropFilter( + input_size, + filter_size, + output_size, + stride, + padding, + data_format="NCHW") if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 2af32b537ba53723370faf81aebf308a465718c7..632eccbb097b4e84f10f926e89d7fa439c8a38cd 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,7 +32,9 @@ from 
tensorflow.python.framework import ops from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -463,7 +466,7 @@ class EagerFunctionTest(xla_test.XLATestCase): def f(x, y): return x[0::2, y:, ...] - x = array_ops.ones([2, 3, 4]) + x = array_ops.ones([2, 3, 4], dtype=dtypes.float32) y = array_ops.ones([], dtype=dtypes.int32) with backprop.GradientTape() as tape: tape.watch(x) @@ -479,15 +482,15 @@ class EagerFunctionTest(xla_test.XLATestCase): @function.defun def times_two(x): - return 2 * x + return 2. * x @function.defun def two_x_plus_1(x): - return times_two(x) + 1 + return times_two(x) + 1. - x = constant_op.constant([2, 3, 4]) + x = constant_op.constant([2., 3., 4.]) y = two_x_plus_1(x) - self.assertAllEqual([5, 7, 9], y.numpy()) + self.assertAllEqual([5., 7., 9.], y.numpy()) def testNestedDefunWithVariable(self): with self.test_scope(): @@ -506,7 +509,7 @@ class EagerFunctionTest(xla_test.XLATestCase): x = constant_op.constant(3.0) y = f(x) - self.assertEqual(75, y.numpy()) + self.assertEqual(75.0, y.numpy()) def testNestedDefunInGradientTape(self): with self.test_scope(): @@ -555,6 +558,71 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertEqual(9, dy_v0.numpy()) self.assertEqual(15, dy_v1.numpy()) + def testWhileInDefun(self): + with self.test_scope(): + @def_function.function + def f(start): + c = lambda x: math_ops.less(x, 13.0) + b = lambda x: math_ops.add(x, 1.0) + return control_flow_ops.while_loop(c, b, [start]) + + y = f(constant_op.constant(3.0)) + self.assertEqual(13.0, y.numpy()) + + def testAutoGraphWhileInDefun(self): + with self.test_scope(): + @def_function.function + 
def f(start): + x = start + while x < 13.0: + x += 1.0 + return x + + y = f(constant_op.constant(3.0)) + self.assertEqual(13.0, y.numpy()) + + def testCondInDefun(self): + with self.test_scope(): + @def_function.function + def f(pred, value): + fn1 = lambda: math_ops.add(value, 1.0) + fn2 = lambda: math_ops.subtract(value, 1.0) + return control_flow_ops.cond(pred, fn1, fn2) + + plus_one = f(constant_op.constant(True), constant_op.constant(10.0)) + minus_one = f(constant_op.constant(False), constant_op.constant(10.0)) + self.assertEqual(11.0, plus_one.numpy()) + self.assertEqual(9.0, minus_one.numpy()) + + def testAutoGraphCondInDefun(self): + with self.test_scope(): + @def_function.function + def f(pred, value): + if pred: + return value + 1.0 + else: + return value - 1.0 + + plus_one = f(constant_op.constant(True), constant_op.constant(10.0)) + minus_one = f(constant_op.constant(False), constant_op.constant(10.0)) + self.assertEqual(11.0, plus_one.numpy()) + self.assertEqual(9.0, minus_one.numpy()) + + def testScanInDefun(self): + with self.test_scope(): + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='data') + v = constant_op.constant(2.0, name='v') + + @def_function.function + def f(y): + # pylint: disable=unnecessary-lambda + return functional_ops.scan( + lambda a, x: math_ops.multiply(a, x), y, initializer=v) + # pylint: enable=unnecessary-lambda + + r = f(elems) + self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. 
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 374942a0b339b816944ea5529e4f84134b60017b..56a8e1b1667f154f6cec475ee0f4f8b308121c09 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -191,6 +191,20 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 + + # The TensorFlow FusedBatchNormGrad training operation takes two inputs with + # implementation defined values. In theory the only correct value these + # inputs are the corresponding reserve_space_{1|2} outputs from the + # FusedBatchNorm training operation. However, in practice, we rely on the + # first one being mean on {C|G}PU, and the second one being variance on CPU + # and inverse(sqrt(variance + epsilon)) on GPU (we test this assumption + # separately). 
+ reserve_space_1_val = mean_val + if self.device == "XLA_GPU": + reserve_space_2_val = np.reciprocal(np.sqrt(var_val + epsilon)) + else: + reserve_space_2_val = var_val + data_format_src = "NHWC" grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) @@ -207,18 +221,26 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): np.float32, shape=x_val_converted.shape, name="grad") x = array_ops.placeholder( np.float32, shape=x_val_converted.shape, name="x") - mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") - var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") + reserve_space_1 = array_ops.placeholder( + np.float32, shape=scale_shape, name="reserve_space_1") + reserve_space_2 = array_ops.placeholder( + np.float32, shape=scale_shape, name="reserve_space_2") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format=data_format, is_training=True) + grad, + x, + scale, + reserve_space_1, + reserve_space_2, + data_format=data_format, + is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { grad: grad_val_converted, x: x_val_converted, - mean: mean_val, - var: var_val, + reserve_space_1: reserve_space_1_val, + reserve_space_2: reserve_space_2_val, scale: scale_val }) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 0e2d840418156d825e2d141018e49f42374c8fee..42e688174fce9e939feb09e1767ebab31e30a6ee 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -403,6 +403,117 @@ class AdjustSaturationTest(xla_test.XLATestCase): self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) +class 
ResizeNearestNeighborTest(xla_test.XLATestCase): + # TODO(ilch): Wrap each test with `for dtype in self.float_types:` + # Some work to understand how that should be done was presented here: + # cl/227850213 + + def _assertForwardOpMatchesExpected(self, + image_np, + target_shape, + expected=None, + large_tolerance=False, + align_corners=True): + if expected is None: + self.fail("expected must be specified") + with self.cached_session() as sess, self.test_scope(): + image = array_ops.placeholder(image_np.dtype) + resized = gen_image_ops.resize_nearest_neighbor( + image, target_shape, align_corners=align_corners) + out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=2e-4, atol=2e-4) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def testAlignCorners2x2To1x1(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [1, 1], + expected=np.array([[1]], dtype=np.float32)) + + def testAlignCorners1x1To2x2(self): + self._assertForwardOpMatchesExpected( + np.array([[1]], dtype=np.float32), [2, 2], + expected=np.array([[1, 1], [1, 1]], dtype=np.float32)) + + def testAlignCorners1x1To3x3(self): + self._assertForwardOpMatchesExpected( + np.array([[1]], dtype=np.float32), [3, 3], + expected=np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=np.float32)) + + def testAlignCorners2x2To3x3(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [3, 3], + expected=np.array([[1, 2, 2], [3, 4, 4], [3, 4, 4]], dtype=np.float32)) + + def testAlignCorners2x2To4x4(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [4, 4], + expected=np.array( + [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]], + dtype=np.float32), large_tolerance=True) + + def testAlignCorners3x3To2x2(self): + self._assertForwardOpMatchesExpected( + 
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [2, 2], + expected=np.array([[1, 3], [7, 9]], dtype=np.float32)) + + def testAlignCorners4x4To3x3(self): + self._assertForwardOpMatchesExpected( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float32), [3, 3], + expected=np.array([[1, 3, 4], [9, 11, 12], [13, 15, 16]], + dtype=np.float32)) + + def testAlignCorners3x3To4x4(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [4, 4], + expected=np.array( + [[1, 2, 2, 3], [4, 5, 5, 6], [4, 5, 5, 6], [7, 8, 8, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To6x6(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [6, 6], + expected=np.array( + [[1, 1, 2, 2, 3, 3], [1, 1, 2, 2, 3, 3], [4, 4, 5, 5, 6, 6], + [4, 4, 5, 5, 6, 6], [7, 7, 8, 8, 9, 9], [7, 7, 8, 8, 9, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To9x9(self): + # The expected matrix might look uneven in terms of how many of each number + # there is, but this is an artifact of doing the dilation and convolution + # iteratively. The behavior is less esoteric in the 3x3To12x12 case below. 
+ self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [9, 9], + expected=np.array( + [[1, 2, 2, 2, 2, 3, 3, 3, 3], [4, 5, 5, 5, 5, 6, 6, 6, 6], + [4, 5, 5, 5, 5, 6, 6, 6, 6], [4, 5, 5, 5, 5, 6, 6, 6, 6], + [4, 5, 5, 5, 5, 6, 6, 6, 6], [7, 8, 8, 8, 8, 9, 9, 9, 9], + [7, 8, 8, 8, 8, 9, 9, 9, 9], [7, 8, 8, 8, 8, 9, 9, 9, 9], + [7, 8, 8, 8, 8, 9, 9, 9, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To12x12(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [12, 12], + expected=np.array([[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]], + dtype=np.float32)) + + class ResizeBilinearTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, @@ -444,14 +555,14 @@ class ResizeBilinearTest(xla_test.XLATestCase): self.assertAllCloseAccordingToType(expected[np.newaxis, :, :, np.newaxis], out) - def testAlignCorners1x2To3x2(self): + def testAlignCorners1x2To3x3(self): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2]], dtype=dtype), [3, 3], expected=np.array([[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) - def testAlignCorners1x2To3x2Grad(self): + def testAlignCorners1x2To3x3Grad(self): for dtype in self.float_types: self._assertBackwardOpMatchesExpected( np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 
c61965b97fc142ce452cf28def8c937f692d2f84..0eec070a906670ff36c772edda22f8291b5b734a 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -167,6 +167,11 @@ class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase): }, ) def testMatrixBandPart(self, batch_shape, rows, cols): + # TODO(b/125505881): Disabled due to LLVM backend crash. + if self.device == 'XLA_CPU' and cols == 7 and rows == 1 and batch_shape == [ + 1, 3, 2 + ]: + pass for dtype in self.float_types: with self.cached_session(): mat = np.ones(batch_shape + [rows, cols]).astype(dtype) diff --git a/tensorflow/compiler/tests/plugin.bzl b/tensorflow/compiler/tests/plugin.bzl index fbc8781a3e59faecf985cde5114bf56a041c4be0..46a854d1459b7ea9d9fe3cf7689faee557c2cf84 100644 --- a/tensorflow/compiler/tests/plugin.bzl +++ b/tensorflow/compiler/tests/plugin.bzl @@ -18,13 +18,12 @@ # git update-index --assume-unchanged tensorflow/compiler/tests/plugin.bzl plugins = { - #"example": { - # "device":"XLA_MY_DEVICE", - # "types":"DT_FLOAT,DT_HALF,DT_INT32", - # "tags":[], - # "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"], - # "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"], - # "deps":[], - #}, + #"example": { + # "device":"XLA_MY_DEVICE", + # "types":"DT_FLOAT,DT_HALF,DT_INT32", + # "tags":[], + # "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"], + # "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"], + # "deps":[], + #}, } - diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py index 80c338513bc9ff6b8e56c5ad6b904af9e06a3715..cd9b728ab314d29e4eb585e00a9131024ea3a207 100644 --- a/tensorflow/compiler/tests/quantized_ops_test.py +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -18,11 +18,16 @@ from __future__ import absolute_import from __future__ import 
division from __future__ import print_function +import math import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -44,5 +49,55 @@ class QuantizedOpsTest(xla_test.XLATestCase): self.assertAllEqual(value, expected) +class DeuantizedOpsTest(xla_test.XLATestCase): + + def pack_uint8_r2_to_uint32(self, test_input): + num_rows, num_columns = test_input.get_shape().as_list() + num_output_columns = int(math.ceil(num_columns / 4.0)) + padding_input = array_ops.pad( + math_ops.cast(test_input, dtype=dtypes.uint8), + constant_op.constant([[ + 0, + 0, + ], [0, num_output_columns * 4 - num_columns]])) + output = array_ops.zeros([num_rows, num_output_columns], + dtype=dtypes.uint32) + num_elements_per_pack = 4 + shift_bits = 8 + + iota_r1 = math_ops.range(num_output_columns * num_elements_per_pack) + + for p in range(num_elements_per_pack): + selected_index = math_ops.equal( + math_ops.mod(iota_r1, num_elements_per_pack), p) + gather_index = array_ops.boolean_mask(iota_r1, selected_index) + gathered_input = array_ops.gather(padding_input, gather_index, axis=1) + total_shift_bits = shift_bits * (num_elements_per_pack - p - 1) + left_shift_input = bitwise_ops.left_shift( + math_ops.cast(gathered_input, dtype=dtypes.uint32), total_shift_bits) + output = bitwise_ops.bitwise_or(output, left_shift_input) + return output + + def testDequantizeQuint8(self): + num_rows = 100 + num_columns = 3547 + random_input = np.random.normal(128.0, 10.0, [num_rows, num_columns]) + with self.cached_session() as session: + with ops.device("CPU"): + test_input = ops.convert_to_tensor(random_input, 
dtype=dtypes.float32) + transposed_input = array_ops.transpose(test_input, [1, 0]) + quantized_input = array_ops.quantize(transposed_input, 0.0, 255.0, + dtypes.quint8) + packed_input = self.pack_uint8_r2_to_uint32(quantized_input.output) + with self.test_scope(): + transposed_quantized_output = xla.dequantize(packed_input, 0.0, 255.0, + "MIN_COMBINED", True) + quantized_output = array_ops.slice(transposed_quantized_output, [0, 0], + [num_rows, num_columns]) + + value = session.run(quantized_output) + self.assertAllClose(value, random_input, 1.0) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 97ffad34c00b8ec16eb1ec109ba5d980e0ce673d..34f2465ba63f235f893db9dd6930ac252c3e7226 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -122,8 +122,8 @@ class RandomOpsTest(xla_test.XLATestCase): beta = (b - mu) / sigma z = normal_cdf(beta) - normal_cdf(alpha) - self.assertTrue((y >= a).sum() == count) - self.assertTrue((y <= b).sum() == count) + self.assertEqual((y >= a).sum(), count) + self.assertEqual((y <= b).sum(), count) # For more information on these calculations, see: # Burkardt, John. "The Truncated Normal Distribution". diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index d23fd125163d1afe8c7fd5e008d4b617ff4b2874..1521cc760b85b176acb27c1489640e92ef90e247 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -63,6 +63,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -80,6 +81,7 @@ int64 tf_xla_random_seed = 0; int32 tf_xla_test_repetitions = 20; int64 tf_xla_max_tensor_size = 10000LL; string* tf_xla_test_device_ptr; // initial value set in main() +string* tf_xla_reference_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; string LocalDeviceToFullDeviceName(const string& device) { @@ -321,6 +323,9 @@ class OpTest : public ::testing::Test { // for use as reduction indices. Tensor RandomReductionIndices(int rank); + // Returns a random bit. + bool RandomBool(); + struct WindowedSpatialDims { Padding padding; std::vector kernel_dims; @@ -453,6 +458,11 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, return dims; } +bool OpTest::RandomBool() { + std::bernoulli_distribution d(0.5); + return d(generator()); +} + Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); @@ -760,8 +770,22 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { for (int i = 0; i < Tx.size(); ++i) { if (Tx(i) != Ty(i)) { return errors::InvalidArgument(absl::StrCat( - i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), - ". x = ", x.DebugString(), "y = ", y.DebugString())); + i, "-th tensor element isn't equal: ", Str(Tx(i)), " vs. ", + Str(Ty(i)), ". 
x = ", x.DebugString(), "y = ", y.DebugString())); + } + } + return Status::OK(); +} + +Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) { + auto Tx = x.flat(); + auto Ty = y.flat(); + for (int i = 0; i < Tx.size(); ++i) { + if (Tx(i) != Ty(i)) { + return errors::InvalidArgument(absl::StrCat( + i, "-th tensor element isn't equal: ", static_cast(Tx(i)), + " vs. ", static_cast(Ty(i)), ". x = ", x.DebugString(), + "y = ", y.DebugString())); } } return Status::OK(); @@ -797,6 +821,8 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, return TensorsAreEqualImpl(a, b); case DT_BOOL: return TensorsAreEqualImpl(a, b); + case DT_BFLOAT16: + return TensorsAreEqualImplBfloat16(a, b); default: LOG(FATAL) << "Unexpected type : " << DataTypeString(a.dtype()); } @@ -829,8 +855,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Input: " << input_tensors.back().DebugString(); } - string cpu_device = - LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0")); + string reference_device = + LocalDeviceToFullDeviceName(*tf_xla_reference_device_ptr); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; @@ -845,9 +871,9 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; Status status = builder.BuildGraph( - absl::StrCat("test", num_tests_, "_expected"), cpu_device, - /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, - &expected_inputs, &expected_fetches); + absl::StrCat("test", num_tests_, "_expected"), reference_device, + /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs, + &expected_fetches); if (!status.ok()) { LOG(ERROR) << "Expected graph construction failed: " << status; return kFatalError; @@ -1371,6 +1397,19 @@ TEST_F(OpTest, Cast) { }); } +TEST_F(OpTest, CastBF16) { + Repeatedly([this]() { + DataType src_type, dst_type; + 
src_type = Choose({DT_FLOAT}); + dst_type = Choose({DT_BFLOAT16}); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") + .RandomInput(src_type) + .Attr("SrcT", src_type) + .Attr("DstT", dst_type) + .Attr("Truncate", true)); + }); +} + TEST_F(OpTest, Ceil) { Repeatedly([this]() { return ExpectTfAndXlaOutputsAreClose( @@ -3346,11 +3385,41 @@ TEST_F(OpTest, ZerosLike) { }); } +// Example failing run: +// --tf_xla_reference_device=GPU:0 +// --tf_xla_test_use_jit=true --tf_xla_test_device=GPU:0 +// --tf_xla_test_repetitions=2 +// --gunit_filter='OpTest.FusedBatchNormTraining' +// --tf_xla_random_seed=2838146746 +TEST_F(OpTest, FusedBatchNormTraining) { + bool is_nhwc = RandomBool(); + std::vector x_dims = RandomDims(/*min_rank=*/4, /*max_rank=*/4, + /*min_size=*/5, /*max_size=*/20); + std::vector scale_dims = {x_dims[is_nhwc ? 3 : 1]}; + std::vector offset_dims = {x_dims[is_nhwc ? 3 : 1]}; + std::vector mean_dims = {0}; + std::vector variance_dims = {0}; + DataType type = DT_FLOAT; + Repeatedly([&] { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FusedBatchNorm") + .RandomInput(type, x_dims) + .RandomInput(type, scale_dims) + .RandomInput(type, offset_dims) + .RandomInput(type, mean_dims) + .RandomInput(type, variance_dims) + .Attr("T", type) + .Attr("data_format", is_nhwc ? 
"NHWC" : "NCHW") + .Attr("epsilon", static_cast(1.001e-05)) + .Attr("is_training", true)); + }); +} } // anonymous namespace } // namespace tensorflow int main(int argc, char** argv) { tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0"); + tensorflow::tf_xla_reference_device_ptr = new tensorflow::string("CPU:0"); std::vector flag_list = { tensorflow::Flag( "tf_xla_random_seed", &tensorflow::tf_xla_random_seed, @@ -3366,6 +3435,9 @@ int main(int argc, char** argv) { "Maximum number of elements for random input tensors."), tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr, "Tensorflow device type to use for test"), + tensorflow::Flag("tf_xla_reference_device", + tensorflow::tf_xla_reference_device_ptr, + "Tensorflow device type to use for reference"), tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit, "Use JIT compilation for the operator under test"), }; diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 693f8513bc54e30060a2e963abd504768535a50a..a9a87b8fb3104f8b9870c41e2aa28b0c48c12921 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -134,6 +134,12 @@ class ScatterNdTest(xla_test.XLATestCase): expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32) self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8])) + def testRepeatedIndices(self): + indices = np.array([[0], [1], [0], [1]], dtype=np.int32) + updates = np.array([9, 10, 11, 12], dtype=np.float32) + expected = np.array([20, 22], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2])) + def testSimple2(self): indices = np.array([[1, 0], [1, 1]], dtype=np.int32) updates = np.array([11., 12.], dtype=np.float32) diff --git a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py new file mode 100644 index 
0000000000000000000000000000000000000000..cfb5c82b22ea1d7400b54045edee0ca0782ce979 --- /dev/null +++ b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py @@ -0,0 +1,62 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.self_adjoint_eig.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.platform import test + + +class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def _test(self, dtype, shape): + np.random.seed(1) + x_np = np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + x_np = x_np + np.swapaxes(x_np, -1, -2) + n = shape[-1] + + e_np, _ = np.linalg.eigh(x_np) + with self.cached_session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + e, v = linalg_ops.self_adjoint_eig(x_tf) + e_val, v_val = sess.run([e, v], feed_dict={x_tf: x_np}) + + v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n) + self.assertAlmostEqual(np.mean(v_diff**2), 0.0, delta=1e-6) + 
self.assertAlmostEqual(np.mean((e_val - e_np)**2), 0.0, delta=1e-6) + + SIZES = [1, 2, 5, 10, 32] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testSelfAdjointEig(self, n, dtype): + for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10): + self._test(dtype, batch_dims + (n, n)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index ee7ca7e6f196e114ff18e2597145e5c198980b08..df5914a518e06e4190c623a14287de8daefebd40 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -167,8 +167,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): beta = (b - mu) / sigma z = normal_cdf(beta) - normal_cdf(alpha) - self.assertTrue((y >= a).sum() == n) - self.assertTrue((y <= b).sum() == n) + self.assertEqual((y >= a).sum(), n) + self.assertEqual((y <= b).sum(), n) # For more information on these calculations, see: # Burkardt, John. "The Truncated Normal Distribution". 
diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 5c079d595c440cac644f5461154509abe7b1d1ed..a380715301b08ce2186c97b678b7235b9121d178 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -23,24 +23,20 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -def scalar_shape(): - return ops.convert_to_tensor([], dtype=dtypes.int32) - - class ListOpsTest(xla_test.XLATestCase): def testElementShape(self): with self.cached_session() as sess, self.test_scope(): dim = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=(dim, 15), num_elements=20, - element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(dim, 15), + element_dtype=dtypes.float32, + max_num_elements=20) e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32) e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64) self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15)) @@ -48,25 +44,44 @@ class ListOpsTest(xla_test.XLATestCase): def testPushPop(self): with self.cached_session() as sess, self.test_scope(): - num = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(7, 15), + element_dtype=dtypes.float32, + max_num_elements=10) l = list_ops.tensor_list_push_back( l, constant_op.constant(1.0, shape=(7, 15))) l = list_ops.tensor_list_push_back( l, constant_op.constant(2.0, shape=(7, 15))) l, e2 = list_ops.tensor_list_pop_back(l, 
element_dtype=dtypes.float32) _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) - self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15))) - self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e2), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15))) + + def testDoNotConstantFoldVariants(self): + with self.cached_session() as sess, self.test_scope(): + val = array_ops.placeholder(dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(7, 15), + element_dtype=dtypes.float32, + max_num_elements=10) + # Note: Pushing a Placeholder will force the constant folding code + # to build a Const node with a DT_VARIANT output. This tests that XLA + # passes a cf_consider_fn which prevent folding such nodes. + l = list_ops.tensor_list_push_back( + l, array_ops.fill(value=val, dims=(7, 15))) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(2.0, shape=(7, 15))) + l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15))) def testPushPopSeparateLists(self): with self.cached_session() as sess, self.test_scope(): - num = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=scalar_shape(), - num_elements=num, - element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=[], + element_dtype=dtypes.float32, + max_num_elements=20) l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0)) @@ -75,22 +90,125 @@ class ListOpsTest(xla_test.XLATestCase): l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) l3, e31 = 
list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) - result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20}) + result = sess.run([e11, [e21, e22], [e31, e32]]) self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) - def testEmptyTensorList(self): - dim = 7 + def testEmptyTensorListNoMax(self): with self.cached_session() as sess, self.test_scope(): - p = array_ops.placeholder(dtypes.int32) l = list_ops.empty_tensor_list( - element_shape=(p, 15), element_dtype=dtypes.float32) + element_shape=(7, 15), element_dtype=dtypes.float32) l = list_ops.tensor_list_push_back( - l, constant_op.constant(1.0, shape=(dim, 15))) + l, constant_op.constant(1.0, shape=(7, 15))) _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Use TensorListReserve instead"): - self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15))) + "Set the max number of elements"): + self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15))) + def testEmptyTensorListMax(self): + with self.cached_session() as sess, self.test_scope(): + l = list_ops.empty_tensor_list( + element_shape=(10, 15), element_dtype=dtypes.float32, + max_num_elements=2) + l = list_ops.tensor_list_push_back( + l, array_ops.fill(value=3.0, dims=(10, 15))) + _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15))) + + def testListFromTensor(self): + with self.cached_session(), self.test_scope(): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e, 1.0) + l, e0 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 2.0) + l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(e1, 
1.0) + self.assertAllEqual(list_ops.tensor_list_length(l), 0) + + def testGetSet(self): + with self.cached_session(), self.test_scope(): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 1.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 2.0]) + + def testSetDoesNotUpdatePushIndex(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_shape=[], element_dtype=dtypes.float32, max_num_elements=2) + # SetItem should not change the push index. + l = list_ops.tensor_list_set_item(l, 1, 3.) + l = list_ops.tensor_list_push_back(l, 5.) + l = list_ops.tensor_list_push_back(l, 7.) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [5., 7.]) + + def testGetSetReserved(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=[], num_elements=2) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 0.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 0.0]) + + def testSetStackReservedUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=None, num_elements=2) + l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]]) + + def testPushInEmptyListWithUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) + 
l = list_ops.tensor_list_push_back(l, [3.0, 4.0]) + # Pushing an element with a different shape should raise an error. + with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"): + l = list_ops.tensor_list_push_back(l, 5.) + self.evaluate( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) + + def testGetSetReservedNonScalar(self): + with self.cached_session() as sess, self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, + element_shape=(7, 15), + num_elements=2) + l = list_ops.tensor_list_set_item( + l, 0, constant_op.constant(1.0, shape=(7, 15))) + e1 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + e2 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e1), np.ones((7, 15))) + self.assertAllEqual(sess.run(e2), np.zeros((7, 15))) + + def testStack(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=[], + max_num_elements=2) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e, 1.0) + l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t.shape.as_list(), [None]) + self.assertAllEqual(t, [1.0, 2.0]) + + def testStackWithUninitializedTensors(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=[], num_elements=3) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [0., 0., 0.]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 95c9e7ffd4651642781143c2c1940b0e51e1e470..f2e0eac2d99fe3b71ecabd4b9977817c5f9c372c 100644 --- 
a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -72,6 +72,7 @@ class UnaryOpsTest(xla_test.XLATestCase): output = op(pinp) result = session.run(output, {pinp: inp}) if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) self.assertAllCloseAccordingToType( result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03) else: @@ -260,7 +261,8 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], + dtype=dtype)).astype(dtype), rtol=1e-4, atol=1e-6) @@ -391,6 +393,11 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.sign, + np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0]], dtype=dtype), + expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0]], dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.is_finite, np.array( @@ -647,7 +654,7 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) - ctypes = {np.complex64: np.float32} + ctypes = {np.complex64: np.float32, np.complex128: np.float64} self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[3 - 4j, -1j, np.inf]], dtype=dtype), @@ -705,7 +712,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), - expected=np.array([[2, 1]], dtype=dtype)) + expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) self._assertOpOutputMatchesExpected( math_ops.negative, @@ -743,6 +750,10 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array( [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], 
dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.sign, + np.array([[np.nan]], dtype=dtype), + expected=np.array([[0.0]], dtype=dtype)) def testLogicalOps(self): self._assertOpOutputMatchesExpected( @@ -760,7 +771,7 @@ class UnaryOpsTest(xla_test.XLATestCase): lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"), np.array( [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), - expected=np.array([10., 26.], dtype=np.float32)) + expected=np.array([14., 22.], dtype=np.float32)) def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] @@ -811,6 +822,12 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1, 2, 0], np.int32), expected=np.array([2, 0, 1], dtype=np.int32)) + def testInvertPermutationTwiceIsNoop(self): + self._assertOpOutputMatchesExpected( + lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)), + np.array([1, 2, 0], np.int32), + expected=np.array([1, 2, 0], dtype=np.int32)) + def testRank(self): rank_op = lambda x: array_ops.rank_internal(x, optimize=False) for dtype in self.numeric_types: @@ -865,6 +882,17 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-1], [1], [4]], dtype=dtype), expected=np.int32(3)) + def testSizeWithInt64OutType(self): + + def size_op(x): + return array_ops.size_internal(x, optimize=False, out_type=np.int64) + + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + size_op, + np.array([[-1], [1], [4]], dtype=dtype), + expected=np.int64(3)) + def testUnpack(self): self._assertOpOutputMatchesExpected( array_ops.unstack, @@ -974,7 +1002,7 @@ class UnaryOpsTest(xla_test.XLATestCase): def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) - expected = np.logaddexp(zero, features) + expected = np.logaddexp(zero, features).astype(dtype) self._assertOpOutputMatchesExpected( nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6) diff --git 
a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index fcd7ac5ba1ca5049246e93e6f5f76746fb28c6b8..18c5870e0decb686f4df1c16bbb4a340c93ad21d 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -485,7 +485,7 @@ class SliceAssignTest(xla_test.XLATestCase): checker2[None] = [6] # new axis def testUninitialized(self): - with self.assertRaisesRegexp(errors.InvalidArgumentError, + with self.assertRaisesRegexp(errors.FailedPreconditionError, "uninitialized variable"): with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable([1, 2]) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 4cf88fc523735cc2d22e085afb83790c7ebb48e4..28274ff799de2c85e1e80512cadbe0206cb640a4 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -319,7 +319,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): session.run(output) self.assertRegexpMatches( invalid_arg_error.exception.message, - (r'^start_indices must be a vector with length equal to input rank, ' + (r'start_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and start_indices has shape \[2\].*')) def testDynamicSliceWithIncorrectSizeIndicesShape(self): @@ -332,7 +332,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): session.run(output) self.assertRegexpMatches( invalid_arg_error.exception.message, - (r'^size_indices must be a vector with length equal to input rank, ' + (r'size_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and size_indices has shape \[2\].*')) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..63cad6a159c3a9b0da9e3bb86ff250dd29e45729 --- /dev/null +++ 
b/tensorflow/compiler/tf2tensorrt/BUILD @@ -0,0 +1,445 @@ +# Description: +# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow +# and provide TensorRT operators and converter package. +# APIs are meant to change over time. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_cuda_library", + "tf_custom_op_library", + "tf_custom_op_library_additional_deps", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) + +tf_cuda_cc_test( + name = "tensorrt_test_cc", + size = "small", + srcs = ["tensorrt_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + "//tensorflow/core:gpu_init", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_custom_op_library( + name = "python/ops/_trt_ops.so", + srcs = [ + "ops/get_serialized_resource_op.cc", + "ops/trt_engine_op.cc", + ], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +cc_library( + name = "trt_op_kernels", + srcs = [ + "kernels/get_serialized_resource_op.cc", + "kernels/trt_engine_op.cc", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":test_utils", + ":trt_allocator", + ":trt_conversion", + ":trt_logging", + ":trt_plugins", + ":trt_resources", + ":utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor_headers_lib", + 
"//tensorflow/core/grappler/costs:graph_properties", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), + alwayslink = 1, +) + +tf_cuda_cc_test( + name = "get_serialized_resource_op_test", + size = "small", + srcs = ["kernels/get_serialized_resource_op_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + # TODO(laigd): consider splitting get_serialized_resource_op out from + # TF-TRT. + ":trt_op_kernels", + ":trt_op_libs", + ":trt_resources", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "trt_engine_op", + "get_serialized_resource_op", + ], +) + +cc_library( + name = "trt_op_libs", + deps = [ + ":get_serialized_resource_op_op_lib", + ":trt_engine_op_op_lib", + ], +) + +tf_cuda_library( + name = "trt_logging", + srcs = ["utils/trt_logger.cc"], + hdrs = ["utils/trt_logger.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_gen_op_wrapper_py( + name = "trt_ops", + deps = [ + ":trt_op_libs", + ], +) + +tf_custom_op_py_library( + name = "trt_ops_loader", + srcs = ["python/ops/trt_ops.py"], + dso = [ + "python/ops/_trt_ops.so", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), + kernels = [ + ":trt_op_kernels", + ":trt_op_libs", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:resources", + ], +) + +tf_cuda_library( + name = "trt_resources", + srcs = [ + "utils/trt_int8_calibrator.cc", + "utils/trt_resources.cc", + ], + hdrs = [ + "utils/trt_int8_calibrator.h", + 
"utils/trt_lru_cache.h", + "utils/trt_resources.h", + ], + deps = [ + ":trt_allocator", + ":trt_logging", + ":utils", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_library( + name = "trt_allocator", + srcs = ["utils/trt_allocator.cc"], + hdrs = ["utils/trt_allocator.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cc_test( + name = "trt_allocator_test", + size = "small", + srcs = ["utils/trt_allocator_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_allocator", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "trt_lru_cache_test", + size = "small", + srcs = ["utils/trt_lru_cache_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_resources", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# Library for the node-level conversion portion of TensorRT operation creation +tf_cuda_library( + name = "trt_conversion", + srcs = [ + "convert/convert_graph.cc", + "convert/convert_nodes.cc", + "convert/trt_optimization_pass.cc", + ], + hdrs = [ + "convert/convert_graph.h", + "convert/convert_nodes.h", + "convert/trt_optimization_pass.h", + ], + deps = [ + ":segment", + ":test_utils", + ":trt_allocator", + ":trt_plugins", + ":trt_logging", + ":trt_resources", + ":utils", + "@com_google_absl//absl/strings", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + 
"//tensorflow/core:gpu_runtime", + "//tensorflow/core:graph", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), +) + +tf_cuda_cc_test( + name = "convert_graph_test", + size = "medium", + srcs = ["convert/convert_graph_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_conversion", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_cc_test( + name = "convert_nodes_test", + size = "medium", + srcs = ["convert/convert_nodes_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_logging", + ":trt_conversion", + ":trt_plugins", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + 
"//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +# Library for the segmenting portion of TensorRT operation creation +cc_library( + name = "segment", + srcs = ["segment/segment.cc"], + hdrs = [ + "segment/segment.h", + "segment/union_find.h", + ], + copts = tf_copts(), + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cuda_cc_test( + name = "segment_test", + size = "small", + srcs = ["segment/segment_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":segment", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +# Library for the plugin factory +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_cc_test( + name = "trt_plugin_factory_test", + size = "small", + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_plugins", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +cc_library( + name = "utils", + srcs = 
["convert/utils.cc"], + hdrs = ["convert/utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib_proto_parsing", + ], +) + +cc_library( + name = "test_utils", + srcs = ["utils/test_utils.cc"], + hdrs = ["utils/test_utils.h"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + "@com_googlesource_code_re2//:re2", + ], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc similarity index 88% rename from tensorflow/contrib/tensorrt/convert/convert_graph.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index ae211a93c3279ff1d6de2f9c9a4b849fc8cd578d..1f3cae3fda0cd7be296882b7b17ea47554edace8 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include #include @@ -24,13 +24,13 @@ limitations under the License. 
#include #include -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" @@ -63,8 +63,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace convert { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; // Returns compiled TRT version information {Maj, Min, Patch} std::vector GetLinkedTensorRTVersion() { @@ -82,63 +82,81 @@ std::vector GetLoadedTensorRTVersion() { } TrtCandidateSelector::TrtCandidateSelector( - const grappler::GraphProperties& graph_properties, int precision_mode) + const grappler::GraphProperties& graph_properties, + TrtPrecisionMode precision_mode) : graph_properties_(graph_properties), precision_mode_(precision_mode) {} Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. 
// LINT.IfChange - static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Sigmoid", - "Tanh", + static const auto* candidate_ops = new std::set{ + "Abs", + "Acos", + "Acosh", "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", + "Asin", + "Asinh", + "Atan", + "Atanh", "AvgPool", + "BatchMatMul", + "BiasAdd", + "Ceil", "ConcatV2", + "Const", + "Conv2D", + "Conv2DBackpropInput", + "Cos", + "Cosh", "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", "Exp", + "ExpandDims", + "Floor", + "FusedBatchNorm", + "FusedBatchNormV2", + "GatherV2", + "Identity", + "LeakyRelu", "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", "Max", + "Maximum", + "MaxPool", + "Mean", "Min", + "Minimum", + "Mul", + "Neg", + "Pad", + "Prod", + "RealDiv", + "Reciprocal", + "Relu", "Relu6", + "Reshape", + "Rsqrt", + "Sigmoid", + "Sin", + "Sinh", + "Slice", + "Snapshot", + "Softmax", + "Sqrt", "Square", - "ExpandDims", "Squeeze", + "StridedSlice", + "Sub", + "Sum", + "Tan", + "Tanh", + "TopKV2", + "Transpose", }; bool is_supported_op_type = - (candidate_ops.count(node->type_string()) || + (candidate_ops->count(node->type_string()) || PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); - static const std::set quantize_ops = { + static const auto* quantize_ops = new std::set{ "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", "FakeQuantWithMinMaxVars", @@ -147,10 +165,11 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // In INT8 mode, we will always apply the quantization ranges provided by // these ops to the relevant tensors. This happens regardless of the value of // use_calibration. 
- if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) { + if (precision_mode_ == TrtPrecisionMode::INT8 && + quantize_ops->count(node->type_string())) { is_supported_op_type = true; } - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) + // LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc) if (!is_supported_op_type) { return errors::Unimplemented("Op type ", node->type_string(), " is not supported"); @@ -184,60 +203,11 @@ tensorflow::Status BuildNodeMap( } // namespace -// Function to get calibration from ResourceMgr and put them into nodedef. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, - bool is_dyn_op) { - LOG(INFO) << "Starting Calib Conversion"; - infer_graph->CopyFrom(graph_def); - auto trt_rm = TRTResourceManager::instance(); - auto calib_rm = trt_rm->getManager("TRTCalibration"); - int num_nodes = infer_graph->node_size(); - if (!is_dyn_op) { - LOG(WARNING) << "Construction of static int8 engine is not implemented " - "yet!. Dynamic engine will be constructed"; - } - for (int i = 0; i < num_nodes; ++i) { - auto n = infer_graph->mutable_node(i); - if (n->op() == "TRTEngineOp") { - VLOG(1) << "Processing " << n->name(); - const string& container_name = n->attr().at("segment_funcdef_name").s(); - TRTCalibrationResource* cres = nullptr; - auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); - if (!status.ok()) { - LOG(ERROR) << "Could not get Calibration information. 
Did you run with " - "calibration data?"; - return tensorflow::errors::FailedPrecondition( - "Need to run graph with calibration data first!"); - } - if (cres->calibrator_) { - cres->calibrator_->waitAndSetDone(); - cres->thr_->join(); - const auto& calibration_table = - cres->calibrator_->getCalibrationTableAsString(); - if (!calibration_table.size()) { - LOG(ERROR) << "Calibration table is empty"; - return tensorflow::errors::Unknown( - "Calibration table is missing. This shouldn't have happened!"); - } - n->mutable_attr()->at("calibration_data").set_s(calibration_table); - } else { - LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; - return tensorflow::errors::Unknown( - "Can't get TRTCalibrator from resource manager!"); - } - cres->Unref(); - TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name)); - } - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode, int minimum_segment_size, bool is_dyn_op, + TrtPrecisionMode precision_mode, int minimum_segment_size, bool is_dyn_op, int max_cached_engines, std::vector cached_engine_batches, bool use_calibration) { // Create GrapplerItem. 
@@ -297,7 +267,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( parameters["max_batch_size"].set_i(max_batch_size); parameters["is_dynamic_op"].set_b(is_dyn_op); parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes); - TF_RETURN_IF_ERROR(GetPrecisionModeName( + TF_RETURN_IF_ERROR(TrtPrecisionModeToName( precision_mode, parameters["precision_mode"].mutable_s())); parameters["maximum_cached_engines"].set_i(max_cached_engines); if (!cached_engine_batches.empty()) { @@ -322,17 +292,23 @@ tensorflow::Status ConvertGraphDefToTensorRT( return Status::OK(); } +struct EdgePtrCompare { + bool operator()(const tensorflow::Edge* lhs, + const tensorflow::Edge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + // Function to get subsegment information structure. tensorflow::Status GetEngineInfo( const tensorflow::Graph* g, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& segment_nodes, + const std::set& segment_nodes, const std::unordered_map& node_map, const std::vector& reverse_topo_order, EngineInfo* info) { - std::vector subgraph_node_ids; // Topologically sorted node ids. - std::set subgraph_node_names = segment_nodes; - std::set added_const_node_ids; // Used to prevent double insertion. + std::vector subgraph_nodes; // Topologically sorted nodes. + std::set added_const_nodes; // Used to prevent double insertion. 
std::set segment_devices; // Map from src_node_name+port to the unique port numbers of the TRT op, where @@ -344,26 +320,45 @@ tensorflow::Status GetEngineInfo( std::unordered_map input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { - const auto& node_name = (*it)->name(); - if (segment_nodes.count(node_name) == 0) continue; - auto node = *it; + const Node* node = *it; + if (segment_nodes.count(node) == 0) continue; auto node_device = node->requested_device(); if (!node_device.empty()) { - segment_devices.insert(node_device); + // If device is CPU, treat as if no device was assigned. Don't add CPU to + // segment_device because that would cause a segfault in + // GetDeviceAndAllocator. This is because GetDeviceAndAllocator assumes + // any already set device is a GPU. + DeviceNameUtils::ParsedName parsed_name; + DeviceNameUtils::ParseFullName(node_device, &parsed_name); + if (parsed_name.type == "CPU") { + VLOG(1) << "Node " << node->name() << " was assigned to the CPU. " + << "Attempting to place on GPU."; + } else { + segment_devices.insert(node_device); + } } else { if (node->has_assigned_device_name()) { + // It appears that nodes will not have assigned devices at this point in + // execution. segment_devices.insert(node->assigned_device_name()); } else { VLOG(2) << "Node " << node->name() << " neither have requested device nor assigned device"; } } + subgraph_nodes.push_back(node); + const int node_id = node->id(); - subgraph_node_ids.push_back(node_id); - // Create input connections. - for (const auto edge : node->in_edges()) { + const string& node_name = node->name(); + + // Create input connections. Sort edges first to make deterministic since + // in_edges is a set of pointers. 
+ std::vector in_edges(node->in_edges().begin(), + node->in_edges().end()); + std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare()); + for (const auto edge : in_edges) { auto input_node = edge->src(); - if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + if (input_node->IsSource() || segment_nodes.count(input_node)) { continue; } if (edge->IsControlEdge()) { @@ -380,12 +375,11 @@ tensorflow::Status GetEngineInfo( // // Note that the segmenter already ensure that the constant data input // is valid and suppported by the engine. - if (!added_const_node_ids.insert(input_node->id()).second) { + if (!added_const_nodes.insert(input_node).second) { // Already added before. continue; } VLOG(1) << "Adding const node " << input_node->name(); - QCHECK(subgraph_node_names.insert(input_node->name()).second); // Since we already add (duplicate) the const input node to the segment // graphdef, it's now not a data dependency any more, but to make the // dependency correct we still add a control dependency. @@ -409,10 +403,14 @@ tensorflow::Status GetEngineInfo( node_id, edge->dst_input(), /*input_edge=*/true, port); } } - // Create output connections. - for (const auto edge : node->out_edges()) { + // Create output connections. Sort edges first to make deterministic since + // out_edges is a set of pointers. + std::vector out_edges(node->out_edges().begin(), + node->out_edges().end()); + std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare()); + for (const auto edge : out_edges) { auto output_node = edge->dst(); - if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + if (output_node->IsSink() || segment_nodes.count(output_node)) { continue; } if (edge->IsControlEdge()) { @@ -440,12 +438,11 @@ tensorflow::Status GetEngineInfo( } // For each segment node in topological order. // Construct the const nodes first. 
- subgraph_node_ids.insert(subgraph_node_ids.begin(), - added_const_node_ids.begin(), - added_const_node_ids.end()); + subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(), + added_const_nodes.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_names, subgraph_node_ids, - &info->connections, &info->segment_graph_def, &info->engine_name)); + g, graph_properties, subgraph_nodes, &info->connections, + &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -566,6 +563,18 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } input_shape_protos.at(conn.port_number) = in_shape; input_shapes.at(conn.port_number) = conn.outside_shape; + // Shape must be fully defined (excluding batch dimension) for static + // mode. + if (info.engine_type == EngineInfo::EngineType::TRTStatic) { + for (int i = 1; i < conn.outside_shape.dims(); i++) { + if (conn.outside_shape.dim_size(i) <= 0) { + return tensorflow::errors::Internal( + "Input shapes must be fully defined when in static mode. " + "Please try is_dynamic_op=True (shape was ", + conn.outside_shape.DebugString(), ")"); + } + } + } // Rewrire data input if it's not found in original graph. tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); @@ -597,7 +606,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } const bool calibrate_int8 = - (info.precision_mode == INT8MODE && info.use_calibration); + (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration); // Build the engine and get its serialized representation. 
string segment_string; if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) { @@ -610,14 +619,15 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, TrtUniquePtrType engine; // TODO(sami): What happens if 1st dim is not batch? TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( - info.segment_graph_def, calibrate_int8 ? FP32MODE : info.precision_mode, + info.segment_graph_def, + calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode, max_batch_size, info.max_workspace_size_bytes, input_shapes, &trt_logger, alloc, /*calibrator=*/nullptr, &engine, info.use_calibration, /*convert_successfully=*/nullptr)); TrtUniquePtrType engine_data(engine->serialize()); - segment_string = - string((const char*)engine_data->data(), engine_data->size()); + segment_string = string(static_cast(engine_data->data()), + engine_data->size()); if (calibrate_int8) { // See above comment about why not putting this inside the 'else' branch. segment_string = info.segment_graph_def.SerializeAsString(); @@ -626,14 +636,8 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, segment_string = info.segment_graph_def.SerializeAsString(); } - // TODO(aaroey): use enum instead, and add a helper method to do the - // conversion. 
string prec_string; - TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); - if (info.precision_mode == INT8MODE && calibrate_int8 && - !TRTResourceManager::instance()->getManager("TRTCalibration")) { - LOG(ERROR) << "Failed to construct calibration storage"; - } + TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string)); tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); if (!info.device.empty()) node_builder.Device(info.device); if (VLOG_IS_ON(1)) { @@ -649,7 +653,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } if (info.engine_type == EngineInfo::EngineType::TRTStatic && - info.cached_engine_batches.size()) { + !info.cached_engine_batches.empty()) { LOG(WARNING) << "Cached engine batches are ignored for static engines"; } tensorflow::NodeDef trt_node; @@ -663,7 +667,6 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, .Attr("serialized_segment", segment_string) .Attr("calibration_data", "") .Attr("max_cached_engines_count", info.maximum_cached_engines) - .Attr("cached_engine_batches", {max_batch_size}) .Attr("workspace_size_bytes", info.max_workspace_size_bytes) .Attr("precision_mode", prec_string) .Attr("use_calibration", info.use_calibration) @@ -815,6 +818,12 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( auto native_segment = fdeflib.add_function(); TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( sgraph, StrCat(engine_name, "_native_segment"), native_segment)); + // Set kIntsOnDeviceAttr to true so that all TRTEngineOp outputs are always on + // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32 + // would be on host if the op generating the tensor has host memory tag set. 
+ (*native_segment + ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr] + .set_b(true); if (VLOG_IS_ON(7)) { VLOG(7) << engine_name << " Function_Def "; VLOG(7) << native_segment->DebugString(); @@ -936,7 +945,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { continue; } curr_engine.precision_mode = params.precision_mode; - if (params.use_calibration && params.precision_mode != INT8MODE) { + if (params.use_calibration && + params.precision_mode != TrtPrecisionMode::INT8) { return errors::InvalidArgument( "Calibration with FP32 or FP16 is not supported."); } @@ -1005,27 +1015,31 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { cudaSetDevice(cuda_device_id); auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, &graph, alloc.get(), &engine_nodes); - // If status is ok, we successfully added the node to the graph and can - // remove segment ops. Otherwise graph is not modified. + string msg = StrCat("TensorRT node ", engine.engine_name, " added for segment ", i, " consisting of ", converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { LOG(INFO) << msg << " succeeded."; - for (auto node_name : converted_segments.at(i).first) { - graph.RemoveNode(node_map.at(node_name)); - } } else { // Graph is not modified. LOG(WARNING) << msg << " failed: " << status << ". Fallback to TF..."; } if (VLOG_IS_ON(1)) { msg = "Segment consists of nodes: "; - for (const string& node_name : converted_segments.at(i).first) { - StrAppend(&msg, node_name, ", "); + for (const Node* node : converted_segments.at(i).first) { + StrAppend(&msg, node->name(), ", "); } VLOG(1) << msg; } + + // If status is ok, we successfully added the node to the graph and can + // remove segment ops. Otherwise graph is not modified. 
+ if (status.ok()) { + for (const Node* node : converted_segments.at(i).first) { + graph.RemoveNode(const_cast(node)); + } + } } cudaSetDevice(old_cuda_device); graph.ToGraphDef(params.output_graph_def); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h similarity index 82% rename from tensorflow/contrib/tensorrt/convert/convert_graph.h rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 1f39f56f6392ba33af3d74fec12c326ed4451cb6..80f68d36a3ab894e97586687ee9ab93dddc73c50 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ #include -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -36,7 +36,7 @@ namespace convert { class TrtCandidateSelector { public: TrtCandidateSelector(const grappler::GraphProperties& graph_properties, - int precision_mode); + TrtPrecisionMode precision_mode); // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added // to TRT subgraph and later converted into TRT engine. 
@@ -52,7 +52,7 @@ class TrtCandidateSelector { const grappler::GraphProperties& graph_properties_; // Quantization ops are only converted when using quantized precisions. - const int precision_mode_; + const TrtPrecisionMode precision_mode_; }; struct ConversionParams { @@ -61,7 +61,7 @@ struct ConversionParams { max_batch_size(1), max_workspace_size_bytes(1 << 30), output_graph_def(nullptr), - precision_mode(1), + precision_mode(TrtPrecisionMode::FP32), minimum_segment_size(3), graph_properties(nullptr), cluster(nullptr), @@ -74,7 +74,7 @@ struct ConversionParams { size_t max_batch_size; size_t max_workspace_size_bytes; tensorflow::GraphDef* output_graph_def; - int precision_mode; + TrtPrecisionMode precision_mode; int minimum_segment_size; const tensorflow::grappler::GraphProperties* graph_properties; const tensorflow::grappler::Cluster* cluster; @@ -85,12 +85,6 @@ struct ConversionParams { std::vector cached_engine_batches; // list of cached engines }; -// This method extracts calibration information from the resource managers -// and puts them in to engine nodedefs. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, - bool is_dyn_op); - // - max_batch_size: maximum batch size which can be used for inference for // optimization targets inference run with max batch size. 
// - max_workspace_size_bytes: The upper bound of memory allowance for engine @@ -99,9 +93,10 @@ tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode = 1, int minimum_segment_size = 3, - bool is_dyn_op = false, int max_cached_engines = 1, - std::vector cached_engine_batches = {}, bool use_calibration = true); + TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32, + int minimum_segment_size = 3, bool is_dyn_op = false, + int max_cached_engines = 1, std::vector cached_engine_batches = {}, + bool use_calibration = true); // Method to call from optimization pass tensorflow::Status ConvertAfterShapes(ConversionParams& params); @@ -123,4 +118,4 @@ std::pair GetDeviceAndAllocator( #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc similarity index 95% rename from tensorflow/contrib/tensorrt/convert/convert_graph_test.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 2d2bfeb192c1893824c7b30bfad593c62c203392..1a754181debf41865190aa7f9ca6a76efea98181 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include #include #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -75,7 +75,7 @@ TEST(TrtCandidateSelector, Basics) { feed, const_1, matmul_attrs); // Unsupported op. - auto unsupported_op = ops::Sin(s.WithOpName("sin"), feed); + auto unsupported_op = ops::Erf(s.WithOpName("sin"), feed); // Incompatible input. auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE); @@ -98,7 +98,8 @@ TEST(TrtCandidateSelector, Basics) { grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - for (const int precision_mode : {FP32MODE, INT8MODE}) { + for (const TrtPrecisionMode precision_mode : + {TrtPrecisionMode::FP32, TrtPrecisionMode::INT8}) { TrtCandidateSelector selector(graph_properties, precision_mode); TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); ExpectStatus( @@ -107,13 +108,13 @@ TEST(TrtCandidateSelector, Basics) { "transpose_a is not supported for TensorRT FullyConnected " "(op: MatMul), at: incompatible_matmul"); ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), - error::UNIMPLEMENTED, "Op type Sin is not supported"); + error::UNIMPLEMENTED, "Op type Erf is not supported"); ExpectStatus( selector.IsTensorRTCandidate( matmul_with_incompatible_input.operation.node()), error::INTERNAL, "Failed to convert input with index 0 to a TRT_TensorOrWeights"); - if (precision_mode == INT8MODE) { + if (precision_mode == 
TrtPrecisionMode::INT8) { TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node())); } else { ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()), diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc similarity index 73% rename from tensorflow/contrib/tensorrt/convert/convert_nodes.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 777a80bbc4da7a260cf85d0a7bc5ec16f4cd3cab..9a2ac8c3e5f1d149baf5de25c940e24a8acc9125 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include #include @@ -24,11 +24,13 @@ limitations under the License. 
#include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -43,6 +45,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" @@ -80,10 +83,16 @@ namespace tensorrt { const char* const kInputPHName = "TensorRTInputPH_"; const char* const kOutputPHName = "TensorRTOutputPH_"; +bool IsEngineInput(absl::string_view name) { + return absl::StartsWith(name, kInputPHName); +} +bool IsEngineOutput(absl::string_view name) { + return absl::StartsWith(name, kOutputPHName); +} + namespace convert { -using ::tensorflow::str_util::Split; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, nvinfer1::DataType* trt_dtype) { @@ -183,6 +192,15 @@ Status ValidateTensorProperties(const string& producer_node_type, *trt_dims = TensorShapeToTrtDims(shape, 
/*ignore_first_dim=*/true); *batch_size = shape.dim_size(0); + // Don't convert empty tensors (dim value of 0). + for (int d = 1; d < shape.dims(); ++d) { + if (shape.dim_size(d) == 0) { + return errors::Unimplemented( + "Input tensor with shape ", shape.DebugString(), + " is an empty tensor, which is not supported by TRT"); + } + } + if (validation_only) return Status::OK(); // Following are validations at runtime. @@ -286,8 +304,8 @@ Status Converter::GetTrtBroadcastShape( const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; auto compute_output_dims = - [max_nb_dims](const TRT_TensorOrWeights& input, int broadcast_num_dims, - int* output_dims_array, nvinfer1::Dims* output_dims) { + [](const TRT_TensorOrWeights& input, int broadcast_num_dims, + int* output_dims_array, nvinfer1::Dims* output_dims) { const nvinfer1::Dims input_dims = input.GetTrtDims(); std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); std::copy(input_dims.d, input_dims.d + input_dims.nbDims, @@ -334,6 +352,67 @@ Status Converter::GetTrtBroadcastShape( return Status::OK(); } +nvinfer1::ITensor* Converter::CreateConstantLayer( + const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) { + nvinfer1::Weights trt_weights = weights.GetTrtWeights(); + nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights); + if (!layer) return nullptr; + const nvinfer1::DataType trt_dtype = trt_weights.type; + nvinfer1::ITensor* trt_tensor = layer->getOutput(0); + // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set + // the data type below, it will always be kFLOAT regardless what the data type + // of the weights is. Once NVIDIA fixes this bug, we should remove the data + // type setting logic below and test should still pass. 
+ trt_tensor->setType(trt_dtype); + return trt_tensor; +} + +tensorflow::Status CreateBroadcastableScalarConstant( + OpConverterParams* params, float value, const nvinfer1::Dims& dims, + const nvinfer1::ITensor** tensor) { + // In order to be broadcastable, the number of dims has to match. + nvinfer1::Dims broadcastable_dims(dims); + for (int i = 0; i < broadcastable_dims.nbDims; i++) { + broadcastable_dims.d[i] = 1; + } + TRT_ShapedWeights weights = params->weight_store->GetTempWeights( + tensorflow::DataType::DT_FLOAT, broadcastable_dims); + auto weights_ptr = + static_cast(const_cast(weights.GetValues())); + weights_ptr[0] = value; + *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name()); + params->converter->ProvideQuantizationRange( + const_cast(*tensor), value, value); + return Status::OK(); +} + +// Convert an axis from TF format to TRT format while validating. TF format +// includes the batch dimension, while TRT does not. TF can also use negative +// indices. +// TODO(tmorris): Use this method in more ops. +tensorflow::Status ConvertAxis(int tf_axis, int trt_nb_dims, + absl::string_view node_name, int* trt_axis) { + const int tf_nb_dims = trt_nb_dims + 1; + // Check bounds. + if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) { + return tensorflow::errors::InvalidArgument( + "Axis value of ", tf_axis, " is out of bounds, must be in range [", + -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name); + } + // Make negative axis positive. + if (tf_axis < 0) tf_axis += tf_nb_dims; + // Don't allow axis to be the batch dimension. + if (tf_axis == 0) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow manipulation of the batch dimension, at ", + node_name); + } + // Remove batch dimension. 
+ *trt_axis = tf_axis - 1; + return Status::OK(); +} + inline bool DimsEqual(const nvinfer1::Dims& dim_l, const nvinfer1::Dims& dim_r) { if (dim_l.nbDims != dim_r.nbDims) { @@ -347,6 +426,15 @@ inline bool DimsEqual(const nvinfer1::Dims& dim_l, return true; } +bool AllLengthsEqual(const std::vector>& inputs) { + if (inputs.size() == 0) return true; + int length = inputs.at(0).size(); + for (int i = 1; i < inputs.size(); i++) { + if (inputs.at(i).size() != length) return false; + } + return true; +} + inline nvinfer1::Dims GetTrtDimsForTensor(const tensorflow::Tensor& tensor) { nvinfer1::Dims dims; dims.nbDims = tensor.dims(); @@ -484,6 +572,16 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { float getDynamicRange() const override { return 0; } #endif +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + bool dynamicRangeIsSet() const override { return true; } + + void resetDynamicRange() override {} + + float getDynamicRangeMin() const override { return 0.f; } + + float getDynamicRangeMax() const override { return 0.f; } +#endif + private: nvinfer1::DataType trt_dtype_; nvinfer1::Dims trt_dims_; @@ -632,6 +730,11 @@ bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } +template <> +int TFAttrs::get(const string& key) const { + return this->at(key)->i(); +} + // TODO(jie): reorder4 & reorder2 should be merged? // TODO(aaroey): fix the order of parameters. 
template @@ -843,7 +946,7 @@ Status TrtNodeValidator::ConvertConstToWeights( } Converter::Converter(nvinfer1::INetworkDefinition* trt_network, - int precision_mode, bool use_calibration) + TrtPrecisionMode precision_mode, bool use_calibration) : trt_network_(trt_network), precision_mode_(precision_mode), use_calibration_(use_calibration) { @@ -870,13 +973,15 @@ Status Converter::ConvertNode(const NodeDef& node_def) { for (size_t i = 0; i < outputs.size(); ++i) { TRT_TensorOrWeights& output = outputs[i]; string output_name = node_def.name(); - if (i != 0) output_name = StrCat(output_name, ":", i); + if (i != 0) absl::StrAppend(&output_name, ":", i); // We need to check the name before setting it. If the input is one of the // engine input, setting the name here will overwrite engine input // bindings which will cause runtime error. + // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer + // in ConvertIdentity. if (output.is_tensor()) { const char* tensor_name = output.tensor()->getName(); - if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) { + if (!IsEngineInput(tensor_name)) { // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename // them to match their corresponding TensorFlow name. 
// Note: ITensors that we create internally within TF-TRT which are @@ -922,22 +1027,45 @@ Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype, } Status Converter::RenameAndMarkOutputTensors( - const std::vector>& output_tensors) { + const std::vector& output_tensors) { for (const auto& output : output_tensors) { TRT_TensorOrWeights tensor_or_weights; - TF_RETURN_IF_ERROR(GetTensorOrWeights(output.first, &tensor_or_weights)); + TF_RETURN_IF_ERROR( + GetTensorOrWeights(output.source_tensor_name, &tensor_or_weights)); if (!tensor_or_weights.is_tensor()) { - return errors::InvalidArgument("Output ", output.first, + return errors::InvalidArgument("Output ", output.source_tensor_name, " is weights not tensor"); } nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); if (tensor == nullptr) { - return errors::NotFound("Output tensor not found: ", output.first); + return errors::NotFound("Output tensor not found: ", + output.source_tensor_name); } - tensor->setName(output.second.c_str()); - VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " - << output.second; + // Check if this tensor has already been marked as an input or output. + // + // ConvertIdentity can cause the same tensor to be repeated in + // output_tensors, which can cause us to overwrite the name of the output + // tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then + // we won't be able to locate OutputPH_0 during runtime. To fix this, + // duplicate the tensor using no-op shuffle. + // + // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer + // in ConvertIdentity. + if (IsEngineInput(tensor->getName()) || IsEngineOutput(tensor->getName())) { + // Using shuffle layer for identity by not setting reshape or transpose. 
+ nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor); + TFTRT_RETURN_ERROR_IF_NULLPTR( + layer, StrCat("Output Copy for ", tensor->getName())); + MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); + tensor = layer->getOutput(0); + } + tensor->setName(output.dest_node_name.c_str()); network()->markOutput(*tensor); + // Set type after marking as output. TRT only supports setType for engine + // outputs and inputs (type is inferred otherwise). + tensor->setType(output.trt_dtype); + VLOG(1) << "Marking output TRT tensor " << output.source_tensor_name + << ", which feeds TF node " << output.dest_node_name; } return Status::OK(); } @@ -1081,11 +1209,9 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, *tensor = layer->getOutput(0); } } else { - nvinfer1::IConstantLayer* layer = - this->network()->addConstant(dims, input.weights().GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); - *tensor = layer->getOutput(0); - if (precision_mode() == INT8MODE && !use_calibration()) { + *tensor = CreateConstantLayer(input.weights(), dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape"); + if (precision_mode() == TrtPrecisionMode::INT8 && !use_calibration()) { // If we are in int8 mode and not calibrating, we need to explicitly set a // quantization range for the output tensor of the IConstantLayer. Here we // set the range to [min(weights), max(weights)]. @@ -1120,7 +1246,7 @@ void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor, } void Converter::MaybeApplyQuantizationRanges() { - if (precision_mode() != INT8MODE) return; + if (precision_mode() != TrtPrecisionMode::INT8) return; // Infer ranges across marked ops. 
PropagateQuantizationRanges(); @@ -1243,6 +1369,39 @@ Status Converter::GetInputs(const tensorflow::NodeDef& node_def, return tensorflow::Status::OK(); } +// Checks that the number of inputs matches, and enforces that the inputs marked +// as true are constant weights. true means that the input must be a weight, +// while false means the input must be a tensor. In the future, false will mean +// the input can be a tensor or weight. +tensorflow::Status CheckInputsWeights( + const OpConverterParams& params, + const std::vector>& inputs_is_weight) { + const auto& inputs = params.inputs; + const auto& node_def = params.node_def; + if (inputs.size() != inputs_is_weight.size()) { + return tensorflow::errors::InvalidArgument( + node_def.op(), " got ", inputs.size(), " inputs but expected ", + inputs_is_weight.size(), ", at ", node_def.name()); + } + for (int i = 0; i < inputs.size(); i++) { + if (inputs_is_weight[i].second && inputs.at(i).is_tensor()) { + return tensorflow::errors::Unimplemented( + "The input \"", inputs_is_weight[i].first, "\" for ", node_def.op(), + " must be a constant, at ", node_def.name()); + } + // TODO(tmorris): Remove this check and provide a method to automatically + // retrieve an input as a tensor, converting via CreateConstantLayer if it + // was originally a weight. We will want a caching mechanism to prevent many + // duplicate constants from being created. 
+ if (!inputs_is_weight[i].second && inputs.at(i).is_weights()) { + return tensorflow::errors::Unimplemented( + "The input \"", inputs_is_weight[i].first, "\" for ", node_def.op(), + " must be a tensor, at ", node_def.name()); + } + } + return tensorflow::Status::OK(); +} + TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, const TRT_ShapedWeights& weights_src) { auto dtype_new = tensorflow::DataType::DT_HALF; @@ -1435,7 +1594,7 @@ Status BinaryTensorOpWeight(OpConverterParams* params, const_cast(tensor), permutation, &tensor)); } - if (params->converter->precision_mode() == FP16MODE) { + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { weights = ConvertFP32ToFP16(params->weight_store, weights); } @@ -1478,7 +1637,7 @@ Status BinaryTensorOpWeight(OpConverterParams* params, // Because of this issue, fall back to BinaryTensorOpTensor if we are // doing INT8 with no calibration. There is most likely no performance // penalty by falling back here. - if (params->converter->precision_mode() == INT8MODE && + if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && !params->converter->use_calibration()) { return errors::Unimplemented( "Intermediate quantization range cannot be determined without" @@ -1528,80 +1687,126 @@ Status BinaryTensorOpWeight(OpConverterParams* params, return tensorflow::Status::OK(); } -enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; - -tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { +tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group, + bool is_conv2d_backprop_input) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TRT_TensorOrWeights backprop_output_size; + const nvinfer1::ITensor* tensor = nullptr; + if (is_conv2d_backprop_input) { + // In the case when Conv2dBackpropInput is used for conv2d_transpose, these + // inputs correspond to: output size, 
filter, and input. + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}})); + backprop_output_size = inputs.at(0); + tensor = inputs.at(2).tensor(); + } else { + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"filter", true}})); + tensor = inputs.at(0).tensor(); + } + TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + if (weights_rsck.shape_.nbDims != 4) { + return tensorflow::errors::InvalidArgument( + "Conv2D expects kernel of dimension 4, at " + node_def.name()); + } TFAttrs attrs(node_def); - - int h_index = 2; - int w_index = 3; auto data_format = attrs.get("data_format"); - if (data_format == "NHWC") { + int c_index = (data_format == "NHWC") ? 3 : 1; + int h_index = (data_format == "NHWC") ? 1 : 2; + int w_index = (data_format == "NHWC") ? 2 : 3; + auto tf_dilations = attrs.get>("dilations"); + if (tf_dilations.size() != 4) { + return tensorflow::errors::InvalidArgument( + "Convolution dilations field must specify 4 dimensions, at ", + node_def.name()); + } + if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) { + return tensorflow::errors::Unimplemented( + "Dilation rate must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]); + if (is_conv2d_backprop_input && (dilation.d[0] != 1 || dilation.d[1] != 1)) { + return tensorflow::errors::Unimplemented( + "Dilation with Conv2DBackpropInput (conv2d_transpose) is not supported", + ", at ", node_def.name()); + } + + const auto tf_stride = attrs.get>("strides"); + if (tf_stride.size() != 4) { + return tensorflow::errors::InvalidArgument( + "Convolution strides field must specify 4 dimensions, at ", + node_def.name()); + } + if (tf_stride[0] != 1 || tf_stride[c_index] != 1) { + return tensorflow::errors::Unimplemented( + "Stride must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + const 
nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + if (params->validation_only) return tensorflow::Status::OK(); + + // Transpose to NCHW (NCHW is required for IConvLayer). + const bool need_transpose = (data_format == "NHWC"); + if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(tensor), {0, 3, 1, 2}, &tensor)); - h_index = 1; - w_index = 2; - // TODO(jie): transpose it } - - // tensor after transpose (NCHW) + // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); - int num_groups = group; - if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution - VLOG(2) << "groups count: " << num_groups; + // group == 0 signifies that this is a depthwise convolution, so set + // num_groups to size of input's channel dim. For a non-depthwise conv, + // num_groups will be 1. + const int num_groups = (group == 0) ? tensor_dim.d[0] : group; - TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); - VLOG(2) << "weight shape: " << weights_rsck.DebugString(); - if (weights_rsck.shape_.nbDims != 4) { - return tensorflow::errors::Internal( - "Conv2D expects kernel of dimension 4, at: " + node_def.name()); - } - if (params->converter->precision_mode() == FP16MODE) { - weights_rsck = - ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { + weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck); } - + // For conv, TF weights are RSCK, and TRT expects KCRS. + // For backprop, TF weights are RSKC, and TRT expects CKRS. + // Therefore, this reorder will work for both cases. TRT_ShapedWeights weights = params->weight_store->GetTempWeights(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); TRT_ShapedWeights biases(weights.type_); - const int noutput = weights.shape_.d[0] * num_groups; + const int output_axis = is_conv2d_backprop_input ? 
1 : 0; + const int noutput = weights.shape_.d[output_axis] * num_groups; nvinfer1::DimsHW kernel_size; kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; - VLOG(2) << "RSCK: " << weights.DebugString(); - VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w(); - - // TODO(jie): stride. (NHWC/NCHW) - const auto tf_stride = attrs.get>("strides"); - VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index; - VLOG(2) << "stride: " << tf_stride[0] << tf_stride[1] << tf_stride[2] - << tf_stride[3]; - const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + // Add padding. std::vector> padding; - // TODO(jie): padding. if (attrs.get("padding") == "SAME") { - // This is NCHW tensor with no batch dimension. - // 1 -> h - // 2 -> w - padding = CreateSamePadding( - stride, kernel_size, - {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); + nvinfer1::DimsHW effective_kernel_size = kernel_size; + effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1); + effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1); + std::vector input_dims; + if (is_conv2d_backprop_input) { + // For backprop, calculate padding based on "input_sizes" input, which + // actually corresponds to output size. ("input_sizes" makes sense in the + // context of Conv2DBackpropInput). + // We use h_index and w_index instead of 1 and 2 because we havent + // transposed backprop_output_size along with the input. + auto output_size_weights = static_cast( + const_cast(backprop_output_size.weights().GetValues())); + input_dims = {output_size_weights[h_index], output_size_weights[w_index]}; + } else { + // Use 1 and 2 because tensor_dim has the dimensions of the transposed + // input. 
+ input_dims = {static_cast(tensor_dim.d[1]), + static_cast(tensor_dim.d[2])}; + } + padding = CreateSamePadding(stride, effective_kernel_size, input_dims); } else { padding = {{0, 0}, {0, 0}}; } - if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - // TODO(jie): handle asymmetric padding - VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second - << padding[1].first << padding[1].second; - VLOG(2) << "TENSOR before: " << DebugString(tensor->getDimensions()); + // Handle asymmetric padding. auto pad_layer = params->converter->network()->addPadding( *const_cast(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), @@ -1611,24 +1816,38 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const_cast(tensor), pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); - VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions()); } - nvinfer1::IConvolutionLayer* layer = - params->converter->network()->addConvolution( - *const_cast(tensor), noutput, kernel_size, - weights.GetTrtWeights(), biases.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Add convolution. 
+ nvinfer1::ILayer* conv_layer = nullptr; + if (is_conv2d_backprop_input) { + nvinfer1::IDeconvolutionLayer* layer = + params->converter->network()->addDeconvolution( + *const_cast(tensor), noutput, kernel_size, + weights.GetTrtWeights(), biases.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); + conv_layer = layer; + } else { + nvinfer1::IConvolutionLayer* layer = + params->converter->network()->addConvolution( + *const_cast(tensor), noutput, kernel_size, + weights.GetTrtWeights(), biases.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); + layer->setDilation(dilation); + conv_layer = layer; + } + const nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); - layer->setStride(stride); - layer->setPadding({padding[0].first, padding[1].first}); - layer->setName(node_def.name().c_str()); - layer->setNbGroups(num_groups); - const nvinfer1::ITensor* output_tensor = layer->getOutput(0); - VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions()); - VLOG(2) << "data_format: " << data_format; - if (data_format == "NHWC") { - // TODO(jie): transpose it back! + // Restore transpose. 
+ if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(output_tensor), {0, 2, 3, 1}, &output_tensor)); @@ -1638,18 +1857,6 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { return tensorflow::Status::OK(); } -tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, - ConvolutionType type) { - switch (type) { - case ConvolutionType::DEFAULT: - return ConvertConv2DHelper(params, 1); - case ConvolutionType::DEPTHWISE_CONV: - return ConvertConv2DHelper(params, 0); - } - return tensorflow::errors::Unimplemented("unsupported convolution type at, " + - params->node_def.name()); -} - Status BinaryTensorOpTensor(OpConverterParams* params, const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r) { @@ -1677,6 +1884,13 @@ Status BinaryTensorOpTensor(OpConverterParams* params, "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", status.error_message()); } + TFAttrs attrs(node_def); + nvinfer1::DataType dtype = attrs.get("T"); + if (dtype == nvinfer1::DataType::kINT32) { + return errors::Unimplemented("Binary op ", node_def.op(), + " does not support INT32, at ", + node_def.name()); + } if (params->validation_only) return Status::OK(); const nvinfer1::ITensor* tensor_l = nullptr; @@ -1693,8 +1907,6 @@ Status BinaryTensorOpTensor(OpConverterParams* params, } // Check type consistency. 
- TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("T"); TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) @@ -1754,12 +1966,8 @@ tensorflow::Status ConvertPlugin(OpConverterParams* params) { tensorflow::Status ConvertTranspose(OpConverterParams* params) { const auto& inputs = params->inputs; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at ", params->node_def.name()); - } - + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"x", false}, {"perm", true}})); // Get the permutation from weights. TRT_ShapedWeights weights = inputs.at(1).weights(); const int* weights_ptr = @@ -1792,11 +2000,8 @@ tensorflow::Status ConvertTranspose(OpConverterParams* params) { tensorflow::Status ConvertReshape(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects weights for shape, at ", node_def.name()); - } - + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}})); TRT_TensorOrWeights input_tensor = inputs.at(0); TRT_ShapedWeights weights = inputs.at(1).weights(); if (weights.count() == 0) { @@ -1892,18 +2097,8 @@ tensorflow::Status ConvertReshape(OpConverterParams* params) { tensorflow::Status ConvertExpandDims(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2) { - return tensorflow::errors::InvalidArgument( - "Two inputs expected for ExpandDims, at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "ExpandDims expects tensor for input, at ", node_def.name()); - } - if 
(!inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "ExpandDims expects weights for axis, at ", node_def.name()); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"axis", true}})); // Get input shape as vector. TRT_TensorOrWeights input_tensor = inputs.at(0); const nvinfer1::Dims dims = input_tensor.GetTrtDims(); @@ -1953,14 +2148,7 @@ tensorflow::Status ConvertExpandDims(OpConverterParams* params) { tensorflow::Status ConvertSqueeze(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument( - "One input expected for Squeeze, at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "Squeeze expects tensor for input, at ", node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); // Get input shape. TRT_TensorOrWeights input_tensor = inputs.at(0); const nvinfer1::Dims dims = input_tensor.GetTrtDims(); @@ -1971,7 +2159,7 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { // Mark axes to remove by setting them to 0. TFAttrs attrs(node_def); auto squeeze_dims = attrs.get>("squeeze_dims"); - if (squeeze_dims.size() == 0) { + if (squeeze_dims.empty()) { return tensorflow::errors::Unimplemented( "Squeeze is only implemented for explicit dims, at ", node_def.name()); } @@ -2016,20 +2204,371 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertStridedSliceHelper(OpConverterParams* params, + const TRT_TensorOrWeights& input, + std::vector begin, + std::vector size, + const std::vector& stride) { + const auto& node_def = params->node_def; + // Get input dims. + nvinfer1::Dims dims = input.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Temporarily add batch dimension so that indexes line up properly. 
+ input_dims.insert(input_dims.begin(), -1); + // Check bounds. + for (int i = 1; i < input_dims.size(); i++) { + if (begin[i] < 0 || begin[i] > input_dims[i]) { + return tensorflow::errors::InvalidArgument( + "\"begin\" for dimension ", std::to_string(i), " in ", node_def.op(), + " is out of range, at ", node_def.name()); + } + const int end = begin[i] + size[i]; + if (end < 0 || end > input_dims[i]) { + return tensorflow::errors::InvalidArgument( + "\"begin\" + \"size\" for dimension ", std::to_string(i), " in ", + node_def.op(), " is out of range, at ", node_def.name()); + } + if (size[i] <= 0) { + return tensorflow::errors::InvalidArgument( + "\"size\" cannot be negative or zero for ", node_def.op(), ", at ", + node_def.name()); + } + } +// TRT 5.1 adds a slice layer. For older versions, we attempt to use the +// padding layer with negative padding. +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + // Use ISliceLayer. + nvinfer1::Dims begin_dims, size_dims, stride_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims, + /*ignore_first_dim=*/true)); + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &size_dims, + /*ignore_first_dim=*/true)); + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(stride, &stride_dims, + /*ignore_first_dim=*/true)); + if (params->validation_only) return Status::OK(); + + nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice( + *const_cast(input.tensor()), begin_dims, size_dims, + stride_dims); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return tensorflow::Status::OK(); +#else + // Use IPaddingLayer. + // Strides must be 1 in this case. + for (int x : stride) { + if (x != 1) { + return tensorflow::errors::Unimplemented( + "Strides other than 1 are not supported with this version of TRT, " + "at ", + node_def.name()); + } + } + // Rank must be 2, 3 or 4. 
+ if (input_dims.size() > 4) { + return tensorflow::errors::Unimplemented(node_def.op(), + " for tensors with rank > 4 is " + "not supported in this version of " + "TRT, at ", + node_def.name()); + } + // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. + const bool need_reshape = (input_dims.size() != 4); + int reshape_dims_added = 0; + nvinfer1::Dims reshape_dims; + if (need_reshape) { + // Add new dims after batch dim until tensor is 4D. + while (input_dims.size() < 4) { + input_dims.insert(input_dims.begin() + 1, 1); + begin.insert(begin.begin() + 1, 0); + size.insert(size.begin() + 1, 1); + reshape_dims_added++; + } + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims, + /*ignore_first_dim=*/true)); + } + // Find dimensions which need to be sliced. + std::vector pad_dims; + for (int i = 1; i < input_dims.size(); i++) { + if ((begin[i] != 0) || (begin[i] + size[i] != input_dims[i])) { + pad_dims.push_back(i); + } + } + if (pad_dims.empty()) { + // No dimensions are changed, so this is a no-op. We could just return the + // input without creating a new layer. TRT will crash if an empty engine + // with no layers is attempted to be created, so we add a no-op shuffle to + // prevent our unit tests from breaking. + // TODO(tmorris): Allow empty engines in the unit tests and return the input + // as output here. + if (params->validation_only) return Status::OK(); + nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle( + *const_cast(input.tensor())); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return tensorflow::Status::OK(); + } else if (pad_dims.size() == 1) { + // Only one dim is modified but we have to have 2, mark a second dim which + // will have padding of 0. The dim we add is chosen to avoid an unecessary + // transpose. 
+ if (pad_dims[0] != 2) { + pad_dims.push_back(2); + } else { + pad_dims.push_back(3); + } + } else if (pad_dims.size() > 2) { + return tensorflow::errors::Unimplemented( + node_def.op(), + " can only modify up to 2 dimensions in this version of TRT, at ", + node_def.name()); + } + std::sort(pad_dims.begin(), pad_dims.end()); + // Convert to pre/post padding values. Since TRT does not have a StridedSlice + // or Slice layer prior to 5.1, we instead create an IPaddingLayer with + // negative padding. + nvinfer1::DimsHW pre_padding, post_padding; + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + pre_padding.d[i] = -begin[axis]; + post_padding.d[i] = (begin[axis] + size[axis]) - input_dims[axis]; + } + + // IPaddingLayer will always apply the padding to dims 2,3 (input format is + // NCHW). + const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3); + std::vector transpose_order(input_dims.size()); + std::vector inv_transpose_order(input_dims.size()); + if (need_transpose) { + if (pad_dims[0] == 1 && pad_dims[1] == 3) { + transpose_order = {0, 2, 1, 3}; + inv_transpose_order = {0, 2, 1, 3}; + } else if (pad_dims[0] == 1 && pad_dims[1] == 2) { + transpose_order = {0, 3, 1, 2}; + inv_transpose_order = {0, 2, 3, 1}; + } + } + if (params->validation_only) return Status::OK(); + + // Start conversion. 
+ nvinfer1::ITensor* tensor = const_cast(input.tensor()); + if (need_reshape) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input, reshape_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + // Add padding layer + nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( + *const_cast(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable(tensor, + layer->getOutput(0)); + tensor = layer->getOutput(0); + // Restore transpose + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, inv_transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + // Restore reshape + if (need_reshape) { + // Calculate output dimensions + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + input_dims[axis] = size[axis]; + } + // Remove added 1 dimensions + for (int i = 0; i < reshape_dims_added; i++) { + int value = input_dims[1]; + if (value != 1) { + return tensorflow::errors::Internal( + "StridedSlice error when reshaping, at ", node_def.name()); + } + input_dims.erase(input_dims.begin() + 1); + } + + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(tensor), new_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(tensor))); + return 
tensorflow::Status::OK(); +#endif +} + +tensorflow::Status ConvertSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"begin", true}, {"size", true}})); + std::vector begin = inputs.at(1).weights().ToVector(); + std::vector size = inputs.at(2).weights().ToVector(); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + if (!AllLengthsEqual({input_dims, begin, size})) { + return tensorflow::errors::InvalidArgument( + "Length of begin and size arguments must equal rank of input for " + "Slice, at ", + node_def.name()); + } + // Check that batch dimension is unmodified. + const bool begin_is_modified = begin[0] != 0; + // If size[0]s is not -1, we can only know if the batch dimension is + // unmodified when the batch size is defined. When the batch size is + // undefined, we don't convert to be safe. + const bool batch_size_is_defined = input_dims[0] > 0; + const bool size_is_modified = + size[0] != -1 && (!batch_size_is_defined || + (batch_size_is_defined && size[0] != input_dims[0])); + if (begin_is_modified || size_is_modified) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + // Size of -1 signifies to take all remaining elements. + for (int i = 1; i < input_dims.size(); i++) { + if (size[i] == -1) { + size[i] = input_dims[i] - begin[i]; + } + } + // Stride is 1 for all dims. 
+ std::vector stride(begin.size(), 1); + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); +} + +tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}})); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + // Get begin and end bounds per axis. + std::vector begin = inputs.at(1).weights().ToVector(); + std::vector end = inputs.at(2).weights().ToVector(); + std::vector stride = inputs.at(3).weights().ToVector(); + if (!AllLengthsEqual({input_dims, begin, end, stride})) { + return tensorflow::errors::InvalidArgument( + "Length of begin, end, and stride arguments must equal rank of input " + "for StridedSlice, at ", + node_def.name()); + } + // Unsupported mask options. + TFAttrs attrs(node_def); + for (const string& attr : + {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int attr_val = attrs.get(attr); + if (attr_val != 0) { + return tensorflow::errors::Unimplemented( + attr, " is not supported for StridedSlice, at ", node_def.name()); + } + } + const int begin_mask = attrs.get("begin_mask"); + const int end_mask = attrs.get("end_mask"); + // Check that batch dimension is unmodified. + const bool begin_is_modified = !(begin_mask & 1) && begin[0] != 0; + const bool stride_is_modified = stride[0] != 1; + // If the batch size is -1 and the end mask is not set, we can only know if + // the batch dimension is unmodified when the batch size is defined. When the + // batch size is undefined, we don't convert to be safe. 
+ const bool batch_size_is_defined = input_dims[0] > 0; + const bool end_is_modified = + !(end_mask & 1) && (!batch_size_is_defined || + (batch_size_is_defined && end[0] != input_dims[0])); + if (begin_is_modified || stride_is_modified || end_is_modified) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + // Standarize begin and end bounds by applying masks, making negative values + // positive, and correcting out of bounds ranges (StridedSlice does this + // silently). + for (int i = 1; i < input_dims.size(); i++) { + // Begin + if ((1 << i) & begin_mask) { + begin[i] = 0; + } else if (begin[i] < 0) { + begin[i] += input_dims[i]; + } + begin[i] = std::max(0, std::min(begin[i], input_dims[i])); + // End + if ((1 << i) & end_mask) { + end[i] = input_dims[i]; + } else if (end[i] < 0) { + end[i] += input_dims[i]; + } + end[i] = std::max(0, std::min(end[i], input_dims[i])); + } + // Negative or zero strides currently not supported. 
+ for (int i = 0; i < input_dims.size(); i++) { + if (stride[i] <= 0) { + return tensorflow::errors::Unimplemented( + "Negative or zero stride values are not supported for StridedSlice, " + "at ", + node_def.name()); + } + } + // TRT Slice layer uses (begin, size) instead of (begin, end) + std::vector size(input_dims.size()); + for (int i = 0; i < input_dims.size(); i++) { + // Divide by stride (round up) + size[i] = (end[i] - begin[i] + stride[i] - 1) / stride[i]; + } + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); +} + tensorflow::Status ConvertConv2D(OpConverterParams* params) { - return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); + return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/false); } tensorflow::Status ConvertConv2DDepthwise(OpConverterParams* params) { - return ConvertConv2DHelper(params, ConvolutionType::DEPTHWISE_CONV); + return ConvertConv2DHelper(params, 0, /*is_conv2d_backprop_input=*/false); +} + +tensorflow::Status ConvertConv2DBackpropInput(OpConverterParams* params) { + return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/true); } tensorflow::Status ConvertPool(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + nvinfer1::PoolingType type; + if (node_def.op() == "MaxPool") { + type = nvinfer1::PoolingType::kMAX; + } else if (node_def.op() == "AvgPool") { + type = nvinfer1::PoolingType::kAVERAGE; + } else { + return tensorflow::errors::Unimplemented( + "Unsupported pooling type: ", node_def.op(), ", at ", node_def.name()); + } TFAttrs attrs(node_def); + const string padding_type = attrs.get("padding"); + if ((padding_type != "SAME") && (padding_type != "VALID")) { + return tensorflow::errors::Unimplemented( + "Unsupported padding type: ", padding_type, ", at ", node_def.name()); + } + if 
(params->validation_only) return Status::OK(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int h_index = 2; int w_index = 3; const auto data_format = attrs.get("data_format"); @@ -2040,16 +2579,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { const_cast(tensor), {0, 3, 1, 2}, &tensor)); } - nvinfer1::PoolingType type; - if (node_def.op() == "MaxPool") { - type = nvinfer1::PoolingType::kMAX; - } else if (node_def.op() == "AvgPool") { - type = nvinfer1::PoolingType::kAVERAGE; - } else { - return tensorflow::errors::Unimplemented("Unsupported pool type: ", - node_def.op()); - } - const auto tf_stride = attrs.get>("strides"); const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); @@ -2058,7 +2587,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { auto tensor_dim = tensor->getDimensions(); std::vector> padding; - const string padding_type = attrs.get("padding"); if (padding_type == "SAME") { // This is NCHW tensor with no batch dimension. // 1 -> h @@ -2068,9 +2596,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; - } else { - return tensorflow::errors::Unimplemented("Unsupported padding type: ", - padding_type); } if (padding[0].first != padding[0].second || @@ -2112,7 +2637,9 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { return tensorflow::Status::OK(); } -tensorflow::Status ConvertActivation(OpConverterParams* params) { +// TODO(tmorris): Use ActivationType::kLEAKY_RELU in TRT 5.1+ once perf +// improves. 
+tensorflow::Status ConvertLeakyRelu(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 1) { @@ -2124,6 +2651,47 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { node_def.op(), " is only implemented for tensors, at ", node_def.name()); } + TFAttrs attrs(node_def); + const float alpha = attrs.get("alpha"); + if (alpha < 0.0f || alpha > 1.0f) { + return tensorflow::errors::Unimplemented( + "Alpha value for LeakyRelu must be between 0 and 1, at ", + node_def.name()); + } + if (params->validation_only) return tensorflow::Status::OK(); + + // Input Tensor + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + // Create const for alpha. + const nvinfer1::ITensor* const_alpha_tensor = nullptr; + TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( + params, alpha, tensor->getDimensions(), &const_alpha_tensor)); + // alpha * x + nvinfer1::IElementWiseLayer* mul_layer = + params->converter->network()->addElementWise( + *const_cast(tensor), + *const_cast(const_alpha_tensor), + nvinfer1::ElementWiseOperation::kPROD); + TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name()); + // max(x, alpha * x) + nvinfer1::IElementWiseLayer* max_layer = + params->converter->network()->addElementWise( + *const_cast(tensor), + *const_cast(mul_layer->getOutput(0)), + nvinfer1::ElementWiseOperation::kMAX); + TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name()); + nvinfer1::ITensor* output_tensor = max_layer->getOutput(0); + params->converter->MarkQuantizationRangesAsInferrable( + output_tensor, const_cast(mul_layer->getOutput(0))); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} + +tensorflow::Status ConvertActivation(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); static const std::unordered_map ops{ 
{"Relu", nvinfer1::ActivationType::kRELU}, {"Sigmoid", nvinfer1::ActivationType::kSIGMOID}, @@ -2157,19 +2725,19 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { Status ConvertQuantize(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if ((inputs.size() == 0) || - (node_def.op() == "FakeQuantWithMinMaxArgs" && inputs.size() != 1) || - (node_def.op() == "FakeQuantWithMinMaxVars" && inputs.size() != 3) || - (node_def.op() == "QuantizeAndDequantizeV2" && inputs.size() != 3) || - (node_def.op() == "QuantizeAndDequantizeV3" && inputs.size() != 4)) { - return errors::InvalidArgument("Invalid number of inputs for ", - node_def.op(), ", at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - // TensorRT will automatically quantize weights, so we will ignore ranges - // for weights. - params->outputs->push_back(inputs.at(0)); - return Status::OK(); + if (node_def.op() == "FakeQuantWithMinMaxArgs") { + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + } else if (node_def.op() == "FakeQuantWithMinMaxVars") { + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"min", true}, {"max", true}})); + } else if (node_def.op() == "QuantizeAndDequantizeV2") { + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"input_min", true}, {"input_max", true}})); + } else if (node_def.op() == "QuantizeAndDequantizeV3") { + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}, + {"input_min", true}, + {"input_max", true}, + {"num_bits", true}})); } float min_range = 0.0f; float max_range = 0.0f; @@ -2186,11 +2754,6 @@ Status ConvertQuantize(OpConverterParams* params) { node_def.op() == "QuantizeAndDequantizeV2" || node_def.op() == "QuantizeAndDequantizeV3") { // Get ranges via inputs. 
- if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights()) { - return errors::InvalidArgument("Min and max inputs for ", node_def.op(), - " must be weights not tensors, at ", - node_def.name()); - } auto get_weights_value = [&inputs](int index) { auto raw_weights = static_cast( const_cast(inputs.at(index).weights().GetValues())); @@ -2221,20 +2784,11 @@ Status ConvertQuantize(OpConverterParams* params) { return Status::OK(); } -// TODO(pdavoodi): we should update relu6 implementation once TensorRT supports -// Relu6 natively. +// TODO(tmorris): Use ActivationType::kCLIP in TRT 5.1+ once perf improves. tensorflow::Status ConvertRelu6(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument( - "Invalid number of inputs for Relu6, at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "Relu6 is only implemented for tensors, not weights, at ", - node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); if (params->validation_only) return Status::OK(); // *************************************************************************** // TensorRT does not implement Relu6 natively. This function converts Relu6 op @@ -2258,35 +2812,18 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f, 6.0f); - // Create a constant layer to store the floating point weight i.e. 6.0f This - // tensor will be broadcasted uniformly during elementwise `min` operation. 
- // The constant has to have the same rank as the input in order for TRT to - // broadcast - nvinfer1::Dims dims; - dims.nbDims = relu_layer->getOutput(0)->getDimensions().nbDims; - for (int i = 0; i < dims.nbDims; i++) { - dims.d[i] = 1; - } - TRT_ShapedWeights weights = params->weight_store->GetTempWeights( - tensorflow::DataType::DT_FLOAT, dims); - auto weights_ptr = - static_cast(const_cast(weights.GetValues())); - weights_ptr[0] = 6.0f; - nvinfer1::IConstantLayer* const6_layer = - params->converter->network()->addConstant(dims, weights.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(const6_layer, node_def.name()); - params->converter->ProvideQuantizationRange(const6_layer->getOutput(0), 0.0f, - 6.0f); + // Create a constant layer to store the floating point weight i.e. 6.0f + const nvinfer1::ITensor* const6_tensor = nullptr; + TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( + params, 6.0f, relu_layer->getOutput(0)->getDimensions(), &const6_tensor)); // ElementWise Min Operation // Min op is a nop for INT8 execution path, as the input tensor // to this layer will only have values in range [0.f, 6.0f]. 
- const nvinfer1::ITensor* tensor_l = relu_layer->getOutput(0); - const nvinfer1::ITensor* tensor_r = const6_layer->getOutput(0); nvinfer1::IElementWiseLayer* relu6_layer = params->converter->network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), + *const_cast(relu_layer->getOutput(0)), + *const_cast(const6_tensor), nvinfer1::ElementWiseOperation::kMIN); TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); @@ -2299,17 +2836,20 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return errors::InvalidArgument("Input expects tensor and weights, at ", - node_def.name()); + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"value", false}, {"bias", true}})); + TFAttrs attrs(node_def); + tensorflow::DataType tf_dtype = attrs.get("T"); + if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { + return errors::Unimplemented("Data type is not supported, for node ", + node_def.name(), " got ", + DataTypeString(tf_dtype)); } if (params->validation_only) return Status::OK(); nvinfer1::ITensor* tensor = const_cast(inputs.at(0).tensor()); const nvinfer1::Dims original_dims = tensor->getDimensions(); - TFAttrs attrs(node_def); const string data_format = attrs.get("data_format"); const int channel_index = (data_format == "NHWC" ? 
original_dims.nbDims - 1 : 0); @@ -2355,7 +2895,7 @@ tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { } TRT_ShapedWeights weights = inputs.at(1).weights(); - if (params->converter->precision_mode() == FP16MODE) { + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { weights = ConvertFP32ToFP16(params->weight_store, weights); } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; @@ -2399,43 +2939,69 @@ tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { return Status::OK(); } -Status GetTensorDimsWithProtoShape(const Tensor& tensor, - int tensor_proto_array_len, - nvinfer1::Dims* dims) { +void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) { if (tensor.dims() > 0) { *dims = GetTrtDimsForTensor(tensor); - if (TrtDimsNumElements(*dims) != tensor_proto_array_len && - tensor_proto_array_len != 1) { - return errors::InvalidArgument( - "Broadcast on weights only supports kCHANNEL and kUNIFORM"); - } } else { dims->nbDims = 1; // No dimension provided. Flatten it. - dims->d[0] = tensor_proto_array_len; + dims->d[0] = tensor.NumElements(); dims->type[0] = nvinfer1::DimensionType::kSPATIAL; for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; ++i) { dims->d[i] = 0; } } - return Status::OK(); } -template -Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor, - const CType* tensor_proto_array, - int tensor_proto_array_len, TrtWeightStore* store, +Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, TRT_ShapedWeights* weights) { + const DataType dtype = tensor.dtype(); + + // We always convert the integer constants to INT32, since TRT INT8 is for + // quantized inference. + // + // TODO(aaroey): FP16 will remain in half format and is not converted to + // FP32, but the converter currently uses all float weights as FP32. Fix + // this. + const DataType converted_dtype = + (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? 
DT_INT32 + : dtype); + + // Verify that the dtype is supported by TensorRT. Otherwise, return an error. + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); + + if (tensor.NumElements() == 0) { + // Return empty weights having converted dtype. + *weights = TRT_ShapedWeights(converted_dtype); + return Status::OK(); + } + nvinfer1::Dims weight_dims; - TF_RETURN_IF_ERROR(GetTensorDimsWithProtoShape(tensor, tensor_proto_array_len, - &weight_dims)); - *weights = store->GetTempWeights(dtype, weight_dims); - void* dst = const_cast(weights->GetValues()); - if (tensor_proto_array_len == 1) { - std::fill_n((CType*)dst, TrtDimsNumElements(weight_dims), - *tensor_proto_array); + GetTensorDimsWithProtoShape(tensor, &weight_dims); + *weights = weight_store->GetTempWeights(converted_dtype, weight_dims); + + // Copy the tensor directly if the tensor does not require cast to the + // supported type. + if (converted_dtype == dtype) { + char* dst = static_cast(const_cast(weights->GetValues())); + memcpy(dst, tensor.tensor_data().data(), tensor.TotalBytes()); + return Status::OK(); + } + + // Copy tensor elements after casting them to the converted DataType. + int32* dst = static_cast(const_cast(weights->GetValues())); + if (dtype == DT_INT16) { + const int16* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); + } else if (dtype == DT_INT8) { + const int8* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); } else { - memcpy(dst, tensor_proto_array, weights->size_bytes()); + // dtype can only be DT_UINT8 at this point. 
+ TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8); + const uint8* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); } return Status::OK(); } @@ -2453,15 +3019,6 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { "Constant node is expected to have empty input list: ", node_def.name()); } - TFAttrs attrs(node_def); - const DataType dtype = attrs.get("dtype"); - // We always convert the integer constants to kINT32, since TRT kINT8 is for - // quantized inference. - const DataType converted_dtype = - (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 - : dtype); - nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); // Create shaped weights as output const auto& tensor_proto = node_def.attr().at("value").tensor(); @@ -2471,78 +3028,18 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { node_def.name()); } - TRT_ShapedWeights weights(converted_dtype); - if (tensor.NumElements() == 0) { - // Do nothing. - } else if (!tensor_proto.float_val().empty()) { - TF_RETURN_IF_ERROR(TfTensorToTrtWeights( - converted_dtype, tensor, tensor_proto.float_val().begin(), - tensor_proto.float_val_size(), params->weight_store, &weights)); - } else if (!tensor_proto.int_val().empty()) { - TF_RETURN_IF_ERROR(TfTensorToTrtWeights( - converted_dtype, tensor, tensor_proto.int_val().begin(), - tensor_proto.int_val_size(), params->weight_store, &weights)); - } else if (!tensor_proto.half_val().empty()) { - // TODO(aaroey): implement fp16 conversion. - return errors::Unimplemented("fp16 constant is not supported yet."); - } else if (!tensor_proto.tensor_content().empty()) { - // TODO(aaroey): fp16 will remain in half format and is not converted to - // fp32, but the converter currently uses all float weights as fp32. Fix - // this. 
- const auto& content = tensor_proto.tensor_content(); - if (content.size() > 0) { - const int dtype_size = tensorflow::DataTypeSize(dtype); - if (content.size() % dtype_size != 0) { - return errors::FailedPrecondition("Tensor content size ", - content.size(), - " is not a multiple of ", dtype_size); - } - nvinfer1::Dims weights_dim; - TF_RETURN_IF_ERROR(GetTensorDimsWithProtoShape( - tensor, content.size() / dtype_size, &weights_dim)); - const int64_t size_bytes = TrtDimsNumElements(weights_dim) * dtype_size; - if (content.size() != size_bytes) { - return errors::FailedPrecondition( - "Tensor size and TensorProto content size mismatch: ", size_bytes, - " vs ", content.size()); - } else if (tensor.NumElements() != content.size() / dtype_size) { - return errors::FailedPrecondition( - "Tensor elements count and TensorProto content size mismatch: ", - tensor.NumElements(), " vs ", content.size() / dtype_size); - } - weights = - params->weight_store->GetTempWeights(converted_dtype, weights_dim); - if (dtype_size == tensorflow::DataTypeSize(converted_dtype)) { - port::CopyToArray(content, static_cast( - const_cast(weights.GetValues()))); - } else { - // Copy out the weights as original data type. - std::vector temp_weights(content.size()); - port::CopyToArray(content, - reinterpret_cast(temp_weights.data())); - int32* dst = - static_cast(const_cast(weights.GetValues())); - // Copy to the weight store as converted data type. 
- if (dtype == DT_INT16) { - int16* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else if (dtype == DT_INT8) { - int8* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else if (dtype == DT_UINT8) { - uint8* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else { - return errors::FailedPrecondition( - "Unexpected data type: ", DataTypeString(dtype), - " at: ", node_def.name()); - } - } - } - } else { - return errors::Unimplemented("Not supported constant type, at ", - node_def.name()); + TFAttrs attrs(node_def); + const DataType dtype = attrs.get("dtype"); + if (dtype != tensor.dtype()) { + return errors::InvalidArgument("DataType mismatch between attr (", + DataTypeString(dtype), ") and tensor (", + DataTypeString(tensor.dtype()), ")"); } + + TRT_ShapedWeights weights; + TF_RETURN_IF_ERROR( + TfTensorToTrtWeights(tensor, params->weight_store, &weights)); + if (params->outputs != nullptr) { params->outputs->push_back(TRT_TensorOrWeights(weights)); } @@ -2560,9 +3057,13 @@ tensorflow::Status ConvertIdentity(OpConverterParams* params) { Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + // TODO(tmorris): Enable once false is updated to mean either tensor or weight + // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", + // false}})); if (inputs.size() != 2) { - return errors::InvalidArgument("Binary ops require two inputs, at ", - node_def.name()); + return tensorflow::errors::InvalidArgument( + node_def.op(), " got ", inputs.size(), " inputs but expected 2, at ", + node_def.name()); } // Constant folding should have been done by TensorFlow @@ -2601,62 +3102,104 @@ Status ConvertBinary(OpConverterParams* params) { return status; } -tensorflow::Status ConvertUnary(OpConverterParams* params) { +tensorflow::Status 
ConvertRsqrt(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - static const std::unordered_map ops{ - {"Neg", nvinfer1::UnaryOperation::kNEG}, - {"Exp", nvinfer1::UnaryOperation::kEXP}, - {"Log", nvinfer1::UnaryOperation::kLOG}, - {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, - {"Abs", nvinfer1::UnaryOperation::kABS}, - {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, - }; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + if (params->validation_only) return tensorflow::Status::OK(); - if (inputs.size() != 1) { - return tensorflow::errors::FailedPrecondition( - "Unary ops require single tensor input, at ", node_def.name()); + // TODO(tmorris): params->converter is null during validation. Allow + // precision_mode and use_calibration to be accessed during validation and + // include this check in validation. + // We will need a quantization range for intermediate tensor if not using + // calibration. + // + // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) + // ^ + // need range here + if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && + !params->converter->use_calibration()) { + return errors::Unimplemented( + "Intermediate quantization range cannot be determined without" + " calibration for Rsqrt, consider replacing with " + "Sqrt -> FakeQuant -> Reciprocal ops, at ", + node_def.name()); } + // Start conversion. 
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + // Sqrt + nvinfer1::IUnaryLayer* sqrt_layer = params->converter->network()->addUnary( + *const_cast(tensor), nvinfer1::UnaryOperation::kSQRT); + TFTRT_RETURN_ERROR_IF_NULLPTR(sqrt_layer, node_def.name()); + // Recip + nvinfer1::IUnaryLayer* recip_layer = params->converter->network()->addUnary( + *sqrt_layer->getOutput(0), nvinfer1::UnaryOperation::kRECIP); + TFTRT_RETURN_ERROR_IF_NULLPTR(recip_layer, node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(recip_layer->getOutput(0))); + return tensorflow::Status::OK(); +} - // TODO(jie): check type - const nvinfer1::ITensor* tensor = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), inputs.at(0).GetTrtDims(), &tensor)); +const std::unordered_map* +UnaryOperationMap() { + static auto* const m = + new std::unordered_map({ + {"Neg", nvinfer1::UnaryOperation::kNEG}, + {"Exp", nvinfer1::UnaryOperation::kEXP}, + {"Log", nvinfer1::UnaryOperation::kLOG}, + {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"Abs", nvinfer1::UnaryOperation::kABS}, + {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + {"Sin", nvinfer1::UnaryOperation::kSIN}, + {"Cos", nvinfer1::UnaryOperation::kCOS}, + {"Tan", nvinfer1::UnaryOperation::kTAN}, + {"Sinh", nvinfer1::UnaryOperation::kSINH}, + {"Cosh", nvinfer1::UnaryOperation::kCOSH}, + {"Asin", nvinfer1::UnaryOperation::kASIN}, + {"Acos", nvinfer1::UnaryOperation::kACOS}, + {"Atan", nvinfer1::UnaryOperation::kATAN}, + {"Asinh", nvinfer1::UnaryOperation::kASINH}, + {"Acosh", nvinfer1::UnaryOperation::kACOSH}, + {"Atanh", nvinfer1::UnaryOperation::kATANH}, + {"Ceil", nvinfer1::UnaryOperation::kCEIL}, + {"Floor", nvinfer1::UnaryOperation::kFLOOR}, +#endif + }); + return m; +} - nvinfer1::IUnaryLayer* layer; - if (node_def.op() == "Rsqrt") { - // We will need a quantization range for intermediate tensor if not using - 
// calibration. - // - // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) - // ^ - // need range here - if (params->converter->precision_mode() == INT8MODE && - !params->converter->use_calibration()) { - return errors::Unimplemented( - "Intermediate quantization range cannot be determined without" - " calibration for Rsqrt, consider replacing with " - "Sqrt -> FakeQuant -> Reciprocal ops, at ", - node_def.name()); - } - layer = params->converter->network()->addUnary( - *const_cast(tensor), - nvinfer1::UnaryOperation::kSQRT); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - tensor = layer->getOutput(0); - layer = params->converter->network()->addUnary( - *const_cast(tensor), - nvinfer1::UnaryOperation::kRECIP); - } else if (ops.count(node_def.op()) != 0) { - layer = params->converter->network()->addUnary( - *const_cast(tensor), ops.at(node_def.op())); - } else { - return tensorflow::errors::InvalidArgument( - "Binary op: ", node_def.op(), " not supported, at ", node_def.name()); +tensorflow::Status ConvertUnary(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + auto op_pair = UnaryOperationMap()->find(node_def.op()); + if (op_pair == UnaryOperationMap()->end()) { + return tensorflow::errors::Unimplemented( + "Unary op: ", node_def.op(), " not supported at: ", node_def.name()); } + if (params->validation_only) return tensorflow::Status::OK(); + // Start conversion. + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( + *const_cast(tensor), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // Set quantization ranges. 
+ if (node_def.op() == "Sin" || node_def.op() == "Cos") { + params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f); + } else if (node_def.op() == "Asin" || node_def.op() == "Atan") { + params->converter->ProvideQuantizationRange(output_tensor, -M_PI_2, M_PI_2); + } else if (node_def.op() == "Acos") { + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, M_PI); + } else if (node_def.op() == "Neg" || node_def.op() == "Abs") { + // Neg and Abs will have same range as input since TRT uses symmetric + // quantization. + // TODO(tmorris): Should we infer ranges for Ceil and Floor as well? + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), output_tensor); + } params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); return tensorflow::Status::OK(); @@ -2665,14 +3208,7 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { tensorflow::Status ConvertSquare(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument("Square expects one input, at ", - node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "Square is only implemented for tensors, at ", node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); if (params->validation_only) return Status::OK(); // Constant 2 with same rank as input @@ -2685,18 +3221,15 @@ tensorflow::Status ConvertSquare(OpConverterParams* params) { auto weights_ptr = static_cast(const_cast(weights.GetValues())); weights_ptr[0] = 2.f; - nvinfer1::IConstantLayer* const2_layer = - params->converter->network()->addConstant(dims, weights.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(const2_layer, node_def.name()); + nvinfer1::ITensor* const2_tensor = + params->converter->CreateConstantLayer(weights, dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(const2_tensor, 
node_def.name()); // ElementWise Pow Operation - const nvinfer1::ITensor* tensor_l = inputs.at(0).tensor(); - const nvinfer1::ITensor* tensor_r = const2_layer->getOutput(0); nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), - nvinfer1::ElementWiseOperation::kPOW); + *const_cast(inputs.at(0).tensor()), + *const2_tensor, nvinfer1::ElementWiseOperation::kPOW); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -2707,11 +3240,8 @@ tensorflow::Status ConvertSquare(OpConverterParams* params) { tensorflow::Status ConvertReduce(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at", node_def.name()); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"axis", true}})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TRT_ShapedWeights index_list = inputs.at(1).weights(); @@ -2772,12 +3302,8 @@ tensorflow::Status ConvertReduce(OpConverterParams* params) { tensorflow::Status ConvertPad(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - // TODO(aaroey): make a routine for this check and reuse it. 
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at", node_def.name()); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}})); // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); @@ -2814,7 +3340,7 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { } // No padding at all, we should exit - if (pad_index.size() == 0) { + if (pad_index.empty()) { params->outputs->push_back(inputs.at(0)); return tensorflow::Status::OK(); } @@ -2837,6 +3363,7 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { return tensorflow::errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); } + if (params->validation_only) return Status::OK(); bool legit_pad = true; nvinfer1::DimsHW pre_padding(0, 0); @@ -2940,6 +3467,7 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { inputs_vec.push_back(tensor_i); } + if (params->validation_only) return tensorflow::Status::OK(); // nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); nvinfer1::IConcatenationLayer* layer = @@ -2956,17 +3484,30 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, + {"scale", true}, + {"offset", true}, + {"mean", true}, + {"variance", true}})); TFAttrs attrs(node_def); float epsilon = attrs.get("epsilon"); auto data_format = attrs.get("data_format"); if (data_format != "NCHW") { return tensorflow::errors::Unimplemented( - "only data_format=NCHW is supported, at " + node_def.name()); + node_def.op(), " only supports data_format=NCHW, at ", node_def.name()); } bool is_training = attrs.get("is_training"); if 
(is_training) { + // Trying to use batchnorm in training mode is a very common problem. + // Because the error message will only be printed in VLOG(1) by the + // segmenter, we issue a special warning so that users will actually see it. + LOG(WARNING) << node_def.op() << " only supports is_training=false. If you " + << "are using Keras, please call " + << "keras.backend.set_learning_phase(0) before constructing " + << "your model. At " << node_def.name(); return tensorflow::errors::Unimplemented( - "only is_training=false is supported, at " + node_def.name()); + node_def.op(), " only supports is_training=false, at ", + node_def.name()); } nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); @@ -2981,7 +3522,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { for (int i = 1; i < 5; i++) { if (inputs.at(i).weights().type_ != parameter_type) { return tensorflow::errors::Unimplemented( - "Inconsistent parameter type for batchnormis not supported, at: " + + "Inconsistent parameter type for batchnorm is not supported, at: " + node_def.name()); } } @@ -2989,7 +3530,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { TRT_ShapedWeights dummy_power_weights(parameter_type); size_t nweight = 0; for (int i = 1; i < 5; i++) { - nweight = std::max(nweight, (size_t)inputs.at(i).weights().count()); + nweight = std::max(nweight, inputs.at(i).weights().count()); } TRT_ShapedWeights* ptr_shape_weights = nullptr; for (int i = 1; i < 5; i++) { @@ -3001,6 +3542,8 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { "Inconsistent batchnorm parameter count, at: " + node_def.name()); } } + if (params->validation_only) return Status::OK(); + // We could technically have two weights with different shape. 
// that requires two addScale op, arguably less performant TRT_ShapedWeights combined_scale_weights = @@ -3072,6 +3615,29 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertGather(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"params", false}, {"indices", false}, {"axis", true}})); + absl::Span axis = inputs.at(2).weights().GetSpan(); + if (axis.size() != 1) { + return tensorflow::errors::InvalidArgument( + "Axis for GatherV2 must be a scalar, at ", node_def.name()); + } + int trt_axis = 0; + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, + node_def.name(), &trt_axis)); + if (params->validation_only) return Status::OK(); + + nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( + *const_cast(inputs.at(0).tensor()), + *const_cast(inputs.at(1).tensor()), trt_axis); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, TRT_TensorOrWeights tensor_input, TRT_ShapedWeights weights_raw, @@ -3122,14 +3688,9 @@ tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, tensorflow::Status ConvertMatMul(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return errors::InvalidArgument("Input expects tensor and weights, at ", - node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"a", false}, {"b", true}})); TFAttrs attrs(node_def); - // TODO(jie): INT32 should be converted? 
tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { return errors::Unimplemented("Data type is not supported, for node ", @@ -3153,9 +3714,16 @@ tensorflow::Status ConvertMatMul(OpConverterParams* params) { tensorflow::Status ConvertBatchMatMul(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + // TODO(tmorris): Enable once false is updated to mean either tensor or weight + // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", + // false}})); + if (inputs.size() != 2) { + return tensorflow::errors::InvalidArgument( + node_def.op(), " got ", inputs.size(), " inputs but expected 2, at ", + node_def.name()); + } TFAttrs attrs(node_def); - // TODO(jie): INT32 should be converted? tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != tensorflow::DataType::DT_FLOAT && tf_dtype != tensorflow::DataType::DT_HALF) { @@ -3225,6 +3793,7 @@ tensorflow::Status ConvertBatchMatMul(OpConverterParams* params) { tensorflow::Status ConvertSoftmax(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int nbDims = tensor->getDimensions().nbDims; @@ -3233,6 +3802,8 @@ tensorflow::Status ConvertSoftmax(OpConverterParams* params) { "TensorRT Softmax cannot apply on batch dimension, at" + node_def.name()); } + if (params->validation_only) return Status::OK(); + nvinfer1::ISoftMaxLayer* layer = params->converter->network()->addSoftMax( *const_cast(tensor)); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); @@ -3248,31 +3819,36 @@ tensorflow::Status ConvertSoftmax(OpConverterParams* params) { tensorflow::Status ConvertTopK(OpConverterParams* params) { const auto& inputs = params->inputs; + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + 
!inputs.at(1).is_weights()) { + return errors::InvalidArgument("Input expects tensor and weights, at ", + params->node_def.name()); + } + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"k", true}})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - - int nbDims = tensor->getDimensions().nbDims; - if (nbDims == 0) { - return tensorflow::errors::InvalidArgument( - "TensorRT TopK cannot apply on batch dimension, at" + node_def.name()); + const int num_dims = tensor->getDimensions().nbDims; + if (num_dims == 0) { + return errors::InvalidArgument( + "TensorRT TopK cannot apply on batch dimension, at", node_def.name()); } TRT_ShapedWeights k_w = inputs.at(1).weights(); - int k = *(static_cast(const_cast(k_w.GetValues()))); - - nvinfer1::TopKOperation op; - uint32_t reducedAxes = 0; - if (node_def.op() == "TopKV2") { - op = nvinfer1::TopKOperation::kMAX; - reducedAxes |= 1 << (nbDims - 1); - } else { - return tensorflow::errors::Unimplemented( - "Operation: " + node_def.op() + - " not implemented, at: " + node_def.name()); + if (k_w.count() != 1) { + return errors::InvalidArgument("k value of TopK should be a scalar, at", + node_def.name()); } + // Note that ITopKLayer always have sorted outputs, so we don't need to handle + // the 'sorted' attribute of the node. 
+ if (params->validation_only) return Status::OK(); + const nvinfer1::TopKOperation op = nvinfer1::TopKOperation::kMAX; + const int k = *(static_cast(const_cast(k_w.GetValues()))); + const uint32_t reduce_axes = 1 << (num_dims - 1); nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK( - *const_cast(tensor), op, k, reducedAxes); + *const_cast(tensor), op, k, reduce_axes); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); @@ -3286,14 +3862,25 @@ static void RegisterValidatableOpConverters( std::unordered_map* registration) { // TODO(laigd): support all op types. (*registration)["BiasAdd"] = ConvertBiasAdd; + (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; - (*registration)["Transpose"] = ConvertTranspose; - (*registration)["Reshape"] = ConvertReshape; + (*registration)["Conv2D"] = ConvertConv2D; + (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput; + (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; + (*registration)["ExpandDims"] = ConvertExpandDims; + (*registration)["GatherV2"] = ConvertGather; + (*registration)["LeakyRelu"] = ConvertLeakyRelu; (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; + (*registration)["Reshape"] = ConvertReshape; + (*registration)["Rsqrt"] = ConvertRsqrt; + (*registration)["Slice"] = ConvertSlice; (*registration)["Square"] = ConvertSquare; - (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["Squeeze"] = ConvertSqueeze; + (*registration)["StridedSlice"] = ConvertStridedSlice; + (*registration)["Transpose"] = ConvertTranspose; + (*registration)["TopKV2"] = ConvertTopK; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", @@ -3307,6 +3894,15 @@ static void RegisterValidatableOpConverters( for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { 
(*registration)[activation_op_type] = ConvertActivation; } + for (auto pool_op_type : {"AvgPool", "MaxPool"}) { + (*registration)[pool_op_type] = ConvertPool; + } + for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) { + (*registration)[normalization_op_type] = ConvertFusedBatchNorm; + } + for (auto unary_op_pair : *UnaryOperationMap()) { + (*registration)[unary_op_pair.first] = ConvertUnary; + } } void TrtNodeValidator::RegisterOpValidators() { @@ -3315,29 +3911,10 @@ void TrtNodeValidator::RegisterOpValidators() { void Converter::RegisterOpConverters() { RegisterValidatableOpConverters(&op_registry_); - - op_registry_["Conv2D"] = ConvertConv2D; - op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; - op_registry_["MaxPool"] = ConvertPool; - op_registry_["AvgPool"] = ConvertPool; // TODO(ben,jie): this is a temp hack. op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - op_registry_["Pad"] = ConvertPad; - - op_registry_["ConcatV2"] = ConvertConcat; - op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; - op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - - op_registry_["Rsqrt"] = ConvertUnary; - op_registry_["Reciprocal"] = ConvertUnary; - op_registry_["Exp"] = ConvertUnary; - op_registry_["Log"] = ConvertUnary; - op_registry_["Sqrt"] = ConvertUnary; - op_registry_["Abs"] = ConvertUnary; - op_registry_["Neg"] = ConvertUnary; - op_registry_["Sum"] = ConvertReduce; op_registry_["Prod"] = ConvertReduce; op_registry_["Max"] = ConvertReduce; @@ -3345,14 +3922,13 @@ void Converter::RegisterOpConverters() { op_registry_["Mean"] = ConvertReduce; op_registry_["Softmax"] = ConvertSoftmax; op_registry_["BatchMatMul"] = ConvertBatchMatMul; - op_registry_["TopKV2"] = ConvertTopK; plugin_converter_ = ConvertPlugin; } tensorflow::Status ConvertGraphDefToEngine( - const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, - 
size_t max_workspace_size_bytes, + const tensorflow::GraphDef& gdef, TrtPrecisionMode precision_mode, + int max_batch_size, size_t max_workspace_size_bytes, const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, @@ -3367,9 +3943,13 @@ tensorflow::Status ConvertGraphDefToEngine( builder->setMaxBatchSize(max_batch_size); builder->setMaxWorkspaceSize(max_workspace_size_bytes); builder->setGpuAllocator(allocator); - if (precision_mode == FP16MODE) { - builder->setHalf2Mode(true); - } else if (precision_mode == INT8MODE) { + if (precision_mode == TrtPrecisionMode::FP16) { + builder->setFp16Mode(true); + } else if (precision_mode == TrtPrecisionMode::INT8) { + // Setting FP16 mode as well allows TRT to also consider FP16 kernels and + // use them in situations where they are faster than INT8 or where INT8 is + // not supported for a given layer. + builder->setFp16Mode(true); builder->setInt8Mode(true); if (use_calibration) { builder->setInt8Calibrator(calibrator); @@ -3389,15 +3969,14 @@ tensorflow::Status ConvertGraphDefToEngine( // Build the network VLOG(1) << "Starting engine conversion "; Converter converter(trt_network.get(), precision_mode, use_calibration); - std::vector> output_tensors; + std::vector output_tensors; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { string node_name = node_def.name(); VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op(); - if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && - (node_def.op() == "Placeholder")) { + if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32( + if (!tensorflow::strings::safe_strto32( // non-absl ok node_name.c_str() + strlen(kInputPHName), &slot_number)) { return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); @@ -3423,18 +4002,23 @@ 
tensorflow::Status ConvertGraphDefToEngine( // engines offline, by calling sess.run() and cache/serialize the engines. TF_RETURN_IF_ERROR( converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size)); - } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && - (node_def.op() == "Identity")) { + } else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32( + if (!tensorflow::strings::safe_strto32( // non-absl ok node_name.c_str() + strlen(kOutputPHName), &slot_number)) { return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); } + // Get output type that TensorFlow expects + TFAttrs attrs(node_def); + tensorflow::DataType tf_dtype = attrs.get("T"); + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); if (output_tensors.size() <= slot_number) { output_tensors.resize(slot_number + 1); } - output_tensors.at(slot_number) = {node_def.input(0), node_name}; + output_tensors.at(slot_number) = {node_def.input(0), node_name, + trt_dtype}; } else { VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); @@ -3460,8 +4044,7 @@ tensorflow::Status ConvertGraphDefToEngine( tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& subgraph_node_names, - const std::vector& subgraph_node_ids, // In topological order + const std::vector& subgraph_nodes, // In topological order std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope) { std::set marker_nodes; @@ -3524,8 +4107,10 @@ tensorflow::Status ConvertSegmentToGraphDef( marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); tensorflow::NodeDefBuilder builder(node_name, "Identity"); - auto status = builder.Input(connection.inside_node_name, 0, dtype) - .Finalize(seg_node); + auto status = + builder 
+ .Input(connection.inside_node_name, connection.inside_port, dtype) + .Finalize(seg_node); VLOG(1) << "Constructing output " << node_name << " for the edge " << connection.inside_node_name << ":" << connection.inside_port << " -> " << connection.outside_node_name << ":" @@ -3535,13 +4120,12 @@ tensorflow::Status ConvertSegmentToGraphDef( std::unordered_map old_to_new_id_map; // Copy internal nodes to new graphdef - string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name(); - for (const auto node_id : subgraph_node_ids) { - const auto node = graph->FindNodeId(node_id); + string local_scope = subgraph_nodes.front()->name(); + for (const Node* node : subgraph_nodes) { local_scope = GetCommonNameScope(local_scope, node->name()); - old_to_new_id_map[node_id] = segment_def->node_size(); + old_to_new_id_map[node->id()] = segment_def->node_size(); auto snode = segment_def->add_node(); - snode->CopyFrom(node->def()); + *snode = node->def(); VLOG(2) << "Copying " << snode->name() << " to subgraph"; } // Update the inputs of the new input nodes to point to placeholder nodes. @@ -3557,6 +4141,11 @@ tensorflow::Status ConvertSegmentToGraphDef( << placeholder_name; snode->set_input(connection.inside_port, placeholder_name); } + std::set subgraph_node_names; + for (const Node* node : subgraph_nodes) { + subgraph_node_names.insert(node->name()); + } + // Remove control inputs that are not inside the segment. for (int i = 0; i < segment_def->node_size(); ++i) { auto snode = segment_def->mutable_node(i); @@ -3567,7 +4156,7 @@ tensorflow::Status ConvertSegmentToGraphDef( TensorId input = ParseTensorName(snode->input(input_idx)); if (!subgraph_node_names.count( string(input.first.data(), input.first.size())) && - !str_util::StartsWith(input.first, kInputPHName)) { + !IsEngineInput(input.first)) { if (input.second == Graph::kControlSlot) { VLOG(1) << "... 
removing control inputs " << input.first << " from subgraph."; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h similarity index 86% rename from tensorflow/contrib/tensorrt/convert/convert_nodes.h rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 54e19b73957bccdae2b23bd3556de9ad00b864e5..7b37173090519ff6fadd956942d7ea12a0644981 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ #include #include @@ -22,11 +22,11 @@ limitations under the License. 
#include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -92,7 +92,7 @@ struct EngineInfo { EngineInfo() : engine_type(EngineType::TRTStatic), max_workspace_size_bytes(0), - precision_mode(FP32MODE), + precision_mode(TrtPrecisionMode::FP32), use_calibration(true) {} string engine_name; @@ -109,7 +109,7 @@ struct EngineInfo { int64 max_workspace_size_bytes; int maximum_cached_engines; std::vector cached_engine_batches; - int precision_mode; + TrtPrecisionMode precision_mode; bool use_calibration; }; @@ -128,8 +128,7 @@ struct EngineInfo { tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& subgraph_node_names, - const std::vector& subgraph_node_ids, + const std::vector& subgraph_nodes, std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope); @@ -142,8 +141,8 @@ tensorflow::Status ConvertSegmentToGraphDef( // is successful. This is different than successfully building the engine: // building can still fail afterwards. 
tensorflow::Status ConvertGraphDefToEngine( - const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, - size_t max_workspace_size_bytes, + const tensorflow::GraphDef& gdef, TrtPrecisionMode precision_mode, + int max_batch_size, size_t max_workspace_size_bytes, const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, @@ -159,7 +158,10 @@ class OutputEdgeValidator { bool operator()(const tensorflow::Edge* out_edge) const; }; +string DebugString(const nvinfer1::DimensionType type); +string DebugString(const nvinfer1::DataType trt_dtype); string DebugString(const nvinfer1::Dims& dims); +string DebugString(const nvinfer1::Permutation& permutation, int len); string DebugString(const nvinfer1::ITensor& tensor); int64_t TrtDimsNumElements(const nvinfer1::Dims& dims); @@ -176,6 +178,8 @@ class TRT_ShapedWeights { nvinfer1::Weights GetTrtWeights() const; + // Returns the raw pointer to the underlying buffer which holds the weights + // value. void* GetValues() const { return const_cast(tensor_.tensor_data().data()); } @@ -186,6 +190,17 @@ class TRT_ShapedWeights { string DebugString() const; + template + absl::Span GetSpan() const { + return absl::Span(tensor_.flat().data(), count()); + } + + template + std::vector ToVector() const { + auto span = GetSpan(); + return std::vector(span.data(), span.data() + span.size()); + } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; // Note: shape.type[] is not used. tensorflow::DataType type_; @@ -195,6 +210,10 @@ class TRT_ShapedWeights { // underlying buffer. TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor); + // All weights should be stored inside TrtWeightStore to make sure lifetime of + // all the underlying tensors are available until the engine is built. For + // this reason, tensor_ should never be reassigned to a different value that + // is not already present in the TrtWeightStore. 
Tensor tensor_; friend class TrtWeightStore; @@ -394,8 +413,21 @@ class TrtNodeValidator { // Class to convert TF nodes to TRT network. class Converter { public: - Converter(nvinfer1::INetworkDefinition* trt_network, int precision_mode, - bool use_calibration); + // Used for Converter::RenameAndMarkOutputTensors() + struct EngineOutputInfo { + // The TRT tensor name which produces the output. + string source_tensor_name; + // The TensorFlow node name which is receiving the output from the TRT + // engine. This should always be the Identity node created in + // ConvertSegmentToGraphDef. + string dest_node_name; + // Output type. TensorRT requires this to be explicitly set for engine + // outputs. + nvinfer1::DataType trt_dtype; + }; + + Converter(nvinfer1::INetworkDefinition* trt_network, + TrtPrecisionMode precision_mode, bool use_calibration); ////////////////////////////////////////////////////////////////////////////// // Methods used by the TRT engine builder to build a TRT network from a TF @@ -409,13 +441,10 @@ class Converter { Status AddInputTensor(const string& name, nvinfer1::DataType dtype, const nvinfer1::Dims& dims, int batch_size); - // Mark the tensors with names specified by output_tensors[i].first as output - // of the TRT network, and set their names in the TRT network as - // output_tensors[i].second. The tensor names (output_tensors[i].first) are - // standard TF tensor names, i.e. node names followed by output slot number - // (or just the node name if the tensor is the first output of the node). + // Mark the tensors with names specified by source_tensor_name as output of + // the TRT network, and set their names in the TRT network as dest_node_name. 
Status RenameAndMarkOutputTensors( - const std::vector>& output_tensors); + const std::vector& output_tensors); ////////////////////////////////////////////////////////////////////////////// // Methods used by op converters to convert individual TF node and add layers @@ -426,7 +455,7 @@ class Converter { nvinfer1::INetworkDefinition* network() { return trt_network_; } // What precision are we targeting? - int precision_mode() const { return precision_mode_; } + TrtPrecisionMode precision_mode() const { return precision_mode_; } // Calibration will be or was previously performed on this network? bool use_calibration() const { return use_calibration_; } @@ -469,6 +498,11 @@ class Converter { nvinfer1::Dims* operand_l_new_dims, nvinfer1::Dims* operand_r_new_dims) const; + // Creates an IConstantLayer using 'weights' whose dimensions are specified by + // 'dims', and returns the output ITensor. + nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights, + const nvinfer1::Dims& dims); + private: // Verify the provided batch_size is consistent with batch_size_ and update it // if necessary. 
@@ -523,7 +557,7 @@ class Converter { std::vector> quantization_infer_; - const int precision_mode_; + const TrtPrecisionMode precision_mode_; const bool use_calibration_; @@ -537,6 +571,9 @@ class Converter { friend class OpConverterTest; }; +// Map of all supported UnaryOperations +const std::unordered_map* UnaryOperationMap(); + } // namespace convert } // namespace tensorrt } // namespace tensorflow @@ -544,4 +581,4 @@ class Converter { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc similarity index 56% rename from tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index c37a43dd5def9daf3c5d70720c6db2aab20db077..45afc76d758ab5052da78879b27380e2c1ccb5b9 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include #include @@ -21,11 +21,16 @@ limitations under the License. 
#include #include +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -35,7 +40,9 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/public/session.h" @@ -50,9 +57,10 @@ namespace tensorflow { namespace tensorrt { namespace convert { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; +using ::testing::NanSensitiveFloatNear; // TODO(laigd): put this into some test utils file. 
void ExpectStatus(Status status, error::Code code = error::OK, @@ -152,7 +160,7 @@ void ExpectTrtDimsEqualsArray(const std::vector& lhs, } template -void ExpectArrayNear(const std::vector& lhs, const std::vector& rhs) { +void ExpectArrayNear(const std::vector& lhs, absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); for (int i = 0; i < lhs.size(); i++) { EXPECT_FLOAT_EQ(lhs[i], rhs[i]); @@ -163,7 +171,7 @@ void ExpectArrayNear(const std::vector& lhs, const std::vector& rhs) { // EXPECT_FLOAT_EQ. template <> void ExpectArrayNear(const std::vector& lhs, - const std::vector& rhs) { + absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); for (int i = 0; i < lhs.size(); i++) { EXPECT_FLOAT_EQ(Eigen::half_impl::half_to_float(lhs[i]), @@ -234,6 +242,16 @@ class FakeITensor : public nvinfer1::ITensor { float getDynamicRange() const override { return dynamic_range_; } #endif +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + bool dynamicRangeIsSet() const override { return true; } + + void resetDynamicRange() override {} + + float getDynamicRangeMin() const override { return 0.f; } + + float getDynamicRangeMax() const override { return 0.f; } +#endif + private: string name_; nvinfer1::Dims dims_; @@ -364,9 +382,6 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(false, ptr->is_tensor()); EXPECT_EQ(true, ptr->is_weights()); EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights())); - - nvinfer1::Dims dims; - dims.nbDims = 0; ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims()); } } @@ -481,8 +496,7 @@ class ConverterTest : public ::testing::Test { ConverterTest() { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); - converter_.reset(new Converter(network_.get(), - /*precision_mode=*/FP32MODE, + converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32, /*use_calibration=*/false)); weight_store_ = &converter_->weight_store_; } @@ -784,7 +798,7 @@ TEST_F(ConverterTest, 
MaybeApplyQuantizationRanges) { // input -> infer1 -> infer2 -> infer3 FakeITensor input, infer_1, infer_2, infer_3; FakeITensor not_infer; - Converter int8_converter(/*trt_network=*/nullptr, INT8MODE, + Converter int8_converter(/*trt_network=*/nullptr, TrtPrecisionMode::INT8, /*use_calibration=*/true); int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f); int8_converter.ProvideQuantizationRange(¬_infer, -100.0f, 100.0f); @@ -915,6 +929,97 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { "(tensor #dims 4 vs broadcast #dims 5)"); } +TEST_F(ConverterTest, CreateConstantLayer) { + for (auto dtype : {DT_FLOAT, DT_INT32}) { + TRT_ShapedWeights weights = + weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5})); + nvinfer1::ITensor* tensor = + converter_->CreateConstantLayer(weights, GetTestDims({3, 10})); + ASSERT_NE(nullptr, tensor); + EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType()) + << "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. actual " + << DebugString(tensor->getType()); + ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions()); + } +} + +class ConvertGraphDefToEngineTest : public ::testing::Test { + public: + Status RunConvertGraphDefToEngine(Scope* s) { + GraphDef gdef; + TF_EXPECT_OK(s->ToGraphDef(&gdef)); + std::vector input_shapes; + int batch_size = -1; + for (const NodeDef& node : gdef.node()) { + absl::string_view node_name(node.name()); + if (str_util::ConsumePrefix(&node_name, kInputPHName)) { + int port = -1; + EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name(); + if (input_shapes.size() < port + 1) input_shapes.resize(port + 1); + input_shapes[port] = + PartialTensorShape(node.attr().at("shape").shape()); + if (batch_size == -1) { + batch_size = input_shapes[port].dim_size(0); + } else { + EXPECT_EQ(batch_size, input_shapes[port].dim_size(0)); + } + } + } + // TODO(laigd): execute the engine and get outputs. 
+ return ConvertGraphDefToEngine( + gdef, TrtPrecisionMode::FP32, /*max_batch_size=*/1, + /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_, + /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_, + /*use_calibration=*/false, /*convert_successfully=*/nullptr); + } + + protected: + TrtUniquePtrType engine_; + + private: + Logger logger_; +}; + +TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT, + ops::Placeholder::Shape({1, 1})); + auto output = ops::Identity(s.WithOpName("identity1"), input); + output = ops::Identity(s.WithOpName("identity2"), output); + output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output); + // If the converter marks the input tensor as output tensor, the conversion + // below will fail with: + // > TensorRTOutputPH_0 cannot be both input and output + // > Network must have at least one output + TF_EXPECT_OK(RunConvertGraphDefToEngine(&s)); +} + +// Input/output data format for OpConverterTest::BuildAndRun(). +struct InputOutputData { + void* Buffer() const { + return const_cast(tensor.tensor_data().data()); + } + + size_t TotalBytes() const { return tensor.TotalBytes(); } + + const char* name; + Tensor tensor; +}; + +template +Tensor ConstructTensor(int data_size, const T& value = T()) { + std::vector values(data_size, value); + return test::AsTensor(values); +} + +using DataVec = std::vector; + +template +inline absl::Span GetSpanForData(const InputOutputData& data) { + const auto& tensor_map = data.tensor.flat(); + return absl::Span(tensor_map.data(), tensor_map.size()); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. 
class OpConverterTest : public ::testing::Test { @@ -940,11 +1045,11 @@ class OpConverterTest : public ::testing::Test { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); builder_->setMaxBatchSize(1); + builder_->setMaxWorkspaceSize(1 << 26); // Reset the validator and converter. validator_.reset(new TrtNodeValidator); - converter_.reset(new Converter(network_.get(), - /*precision_mode=*/FP32MODE, + converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32, /*use_calibration=*/false)); // Reset other related artifacts. @@ -953,14 +1058,14 @@ class OpConverterTest : public ::testing::Test { } // TODO(laigd): test fp16 and int8 support. - template - void BuildAndRun( - const std::vector>>& - input_data, - const char* output_name, std::vector* output_data) { + void BuildAndRun(const DataVec& input_data, DataVec* output_data) { // Mark the output tensor as TRT engine output. - TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors( - {{string(output_name), string(output_name)}})); + std::vector output_info; + for (const auto& data : *output_data) { + output_info.push_back( + {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())}); + } + TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info)); // Build the TRT engine. ASSERT_EQ(nullptr, engine_.get()); @@ -968,31 +1073,44 @@ class OpConverterTest : public ::testing::Test { CHECK_NOTNULL(engine_.get()); // Execute the TRT engine. 
- ASSERT_LE(input_data.size() + 1, 3); - void* buffers[3]; - for (const auto name_and_data : input_data) { - const int input_size = name_and_data.second.size() * sizeof(T); - const int input_index = engine_->getBindingIndex(name_and_data.first); - ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); - ASSERT_EQ( - 0, cudaMemcpyAsync(buffers[input_index], name_and_data.second.data(), - input_size, cudaMemcpyHostToDevice, stream_)); + const int num_bindings = input_data.size() + output_data->size(); + std::vector buffers(num_bindings); + + for (const auto& data : input_data) { + const int input_index = engine_->getBindingIndex(data.name); + ASSERT_EQ(0, cudaMalloc(&buffers[input_index], data.TotalBytes())); + ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], data.Buffer(), + data.TotalBytes(), cudaMemcpyHostToDevice, + stream_)); + } + struct SizeAndIndex { + SizeAndIndex(int in_size, int in_index) + : size(in_size), index(in_index) {} + int size; + int index; + }; + std::vector output_infos; + for (const auto& data : *output_data) { + const int output_index = engine_->getBindingIndex(data.name); + output_infos.emplace_back(data.TotalBytes(), output_index); + ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes())); } - const int output_size = output_data->size() * sizeof(T); - const int output_index = engine_->getBindingIndex(output_name); - ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size)); - - ASSERT_EQ(engine_->getNbBindings(), input_data.size() + 1); - + ASSERT_EQ(engine_->getNbBindings(), num_bindings); TrtUniquePtrType execution_context( engine_->createExecutionContext()); - execution_context->enqueue(/*batchSize=*/1, buffers, stream_, nullptr); - ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index], - output_size, cudaMemcpyDeviceToHost, stream_)); + execution_context->enqueue(/*batchSize=*/1, buffers.data(), stream_, + nullptr); + + for (int i = 0; i < output_infos.size(); ++i) { + const auto& output_info = 
output_infos[i]; + ASSERT_EQ(0, cudaMemcpyAsync(output_data->at(i).Buffer(), + buffers[output_info.index], output_info.size, + cudaMemcpyDeviceToHost, stream_)); + } cudaStreamSynchronize(stream_); - for (int i = 0; i < input_data.size() + 1; ++i) { + for (int i = 0; i < num_bindings; ++i) { ASSERT_EQ(0, cudaFree(buffers[i])); } } @@ -1111,6 +1229,30 @@ class OpConverterTest : public ::testing::Test { std::unordered_map validator_inputs_; }; +template +void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField* out) { + out->Clear(); + if (tensor.NumElements() == 0) return; + + // TensorProto does not need to have all the elements present and can truncate + // trailing elements with the same value for compressed representation. Such + // elements are derived based on the tensor shape. + const auto flat = tensor.flat(); + int64 last_index = 0; + for (int64 i = 0; i < tensor.NumElements(); ++i) { + if (flat(i) != flat(last_index)) { + last_index = i; + } + } + + int num_out_elements = last_index + 1; + out->Reserve(num_out_elements); + out->AddNAlreadyReserved(num_out_elements); + const T* src = flat.data(); + T* dst = out->mutable_data(); + std::copy(src, src + num_out_elements, dst); +} + template void TestConvertConst(OpConverterTest* test) { NodeDef node_def; @@ -1123,11 +1265,23 @@ void TestConvertConst(OpConverterTest* test) { const std::vector& expected_value) { test->Reset(); - auto& attr = *node_def.mutable_attr(); + TensorProto* tensor_attr = + (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor_attr->Clear(); + if (as_tensor_content) { - tensor.AsProtoTensorContent(attr["value"].mutable_tensor()); + tensor.AsProtoTensorContent(tensor_attr); } else { - tensor.AsProtoField(attr["value"].mutable_tensor()); + tensor.shape().AsProto(tensor_attr->mutable_tensor_shape()); + tensor_attr->set_dtype(tensor.dtype()); + + if (tensor.dtype() == DT_FLOAT) { + CopyTensorElements(tensor, tensor_attr->mutable_float_val()); + } else if (tensor.dtype() 
== DT_INT32) { + CopyTensorElements(tensor, tensor_attr->mutable_int_val()); + } else { + tensor.AsProtoField(tensor_attr); + } } test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; @@ -1140,8 +1294,7 @@ void TestConvertConst(OpConverterTest* test) { { // By default empty tensor will pick DT_FLOAT as data type and we fix it // here. - attr["value"].mutable_tensor()->set_dtype(dtype); - Tensor t; // Empty tensor. + Tensor t(dtype); // Empty tensor. reset_and_test(t, false, {}, {}); } { @@ -1160,6 +1313,22 @@ void TestConvertConst(OpConverterTest* test) { reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6}); reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6}); } + { + // Set all tensor elements to the same value. Such tensors are encoded + // using a single element list in tensor proto. + Tensor t = ::tensorflow::test::AsTensor({1, 1, 1, 1, 1, 1}, + TensorShape({2, 3})); + reset_and_test(t, false, {2, 3}, {1, 1, 1, 1, 1, 1}); + reset_and_test(t, true, {2, 3}, {1, 1, 1, 1, 1, 1}); + } + { + // Set trailing tensor elements to the same value. Such tensors are + // encoded by truncating all equal elements except the first one. + Tensor t = ::tensorflow::test::AsTensor({2, 2, 1, 1, 1, 1}, + TensorShape({2, 3})); + reset_and_test(t, false, {2, 3}, {2, 2, 1, 1, 1, 1}); + reset_and_test(t, true, {2, 3}, {2, 2, 1, 1, 1, 1}); + } } TEST_F(OpConverterTest, ConvertConst) { @@ -1189,7 +1358,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { NodeDef node_def = MakeNodeDef("my_transpose", "Transpose", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_transpose"); + "Transpose got 0 inputs but expected 2, at my_transpose"); } // Get the NodeDef for Transpose. 
@@ -1205,8 +1374,8 @@ TEST_F(OpConverterTest, ConvertTranspose) { AddTestTensor("input", {1, 2, 3}); AddTestTensor("weights", {3}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_transpose"); + node_def, error::UNIMPLEMENTED, + "The input \"perm\" for Transpose must be a constant, at my_transpose"); } { // Transpose at batch dimension, should fail. @@ -1236,10 +1405,12 @@ TEST_F(OpConverterTest, ConvertTranspose) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_transpose", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_transpose", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 4, 2, 5, 3, 6)); } } @@ -1249,7 +1420,7 @@ TEST_F(OpConverterTest, ConvertReshape) { NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects weights for shape, at my_reshape"); + "Reshape got 0 inputs but expected 2, at my_reshape"); } // Get the NodeDef for Reshape. @@ -1265,8 +1436,8 @@ TEST_F(OpConverterTest, ConvertReshape) { AddTestTensor("input", {1, 2, 3}); AddTestTensor("weights", {3}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Input expects weights for shape, at my_reshape"); + node_def, error::UNIMPLEMENTED, + "The input \"shape\" for Reshape must be a constant, at my_reshape"); } { // Reshape to scalar, should fail. 
@@ -1279,11 +1450,6 @@ TEST_F(OpConverterTest, ConvertReshape) { } struct TestParams { - TestParams(int input_batch_size, const std::vector& input_tensor_dims, - const std::vector& input_shape) - : batch_size(input_batch_size), - tensor_dims(input_tensor_dims), - shape(input_shape) {} int batch_size; std::vector tensor_dims; std::vector shape; @@ -1326,10 +1492,12 @@ TEST_F(OpConverterTest, ConvertReshape) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_reshape", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_reshape", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -1339,7 +1507,7 @@ TEST_F(OpConverterTest, ConvertMatMul) { NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_matmul"); + "MatMul got 0 inputs but expected 2, at my_matmul"); } // Get the NodeDef for MatMul. 
@@ -1389,12 +1557,13 @@ TEST_F(OpConverterTest, ConvertMatMul) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); - std::vector output_data(2); - BuildAndRun({{"input", {0, 1}}}, "my_matmul", &output_data); + const DataVec input_data{{"input", test::AsTensor({0, 1})}}; + DataVec output_data{{"my_matmul", ConstructTensor(2)}}; + BuildAndRun(input_data, &output_data); if (transpose_b) { - EXPECT_THAT(output_data, ElementsAre(1, 3)); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { - EXPECT_THAT(output_data, ElementsAre(2, 3)); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(2, 3)); } } } @@ -1448,23 +1617,28 @@ void TestConvertBiasAdd(OpConverterTest* test) { const int num_input = TrtDimsNumElements(GetTestDims(dims_array)); ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2), num_input); - std::vector output_data(num_input); - test->BuildAndRun( - {{"input", std::vector(num_input, CType(0))}}, "my_biasadd", - &output_data); + + const DataVec input_data{ + {"input", ConstructTensor(num_input, CType(0))}}; + DataVec output_data{{"my_biasadd", ConstructTensor(num_input)}}; + test->BuildAndRun(input_data, &output_data); if (trt_input_rank == 1) { if (data_format == "NHWC") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2), CType(3))); } else { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2))); } } else { if (data_format == "NHWC") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3), - CType(1), CType(2), CType(3))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2), CType(3), CType(1), + CType(2), CType(3))); } else { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(1), CType(1), - CType(2), CType(2), CType(2))); + 
EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(1), CType(1), CType(2), + CType(2), CType(2))); } } } @@ -1477,7 +1651,7 @@ TEST_F(OpConverterTest, ConvertBiasAdd) { NodeDef node_def = MakeNodeDef("my_biasadd", "BiasAdd", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_biasadd"); + "BiasAdd got 0 inputs but expected 2, at my_biasadd"); } // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test @@ -1542,21 +1716,25 @@ void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(2); - test->BuildAndRun( - {{"input", - /*input_data=*/swap_inputs ? operand2 : operand1}}, - "my_binary", &output_data); + const DataVec input_data{ + {"input", test::AsTensor(swap_inputs ? operand2 : operand1)}}; + DataVec output_data{{"my_binary", ConstructTensor(2)}}; + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, ElementsAre(CType(5), CType(10.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(5), CType(10.5))); } else if (node_def.op() == "Sub") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(4.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(4.5))); } else if (node_def.op() == "Mul") { - EXPECT_THAT(output_data, ElementsAre(CType(6), CType(22.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(6), CType(22.5))); } else if (node_def.op() == "Div") { - EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1.5), CType(2.5))); } else if (node_def.op() == "RealDiv") { - EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1.5), CType(2.5))); } else { 
ASSERT_TRUE(false); } @@ -1591,13 +1769,14 @@ void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(4); - test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + test->BuildAndRun(input_data, &output_data); if (weights_dims.size() == 1) { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(22), CType(13), CType(24))); } else { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(12), CType(23), CType(24))); } } @@ -1625,9 +1804,10 @@ void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(4); - test->BuildAndRun({{"input", input}}, "my_binary", &output_data); - EXPECT_THAT(output_data, + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + test->BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(12), CType(13), CType(14))); } @@ -1675,17 +1855,19 @@ void TestBinaryTensorOpWeightFallback(OpConverterTest* test, // Check the result of running the engine. 
const int expected_num_outputs = TrtDimsNumElements(GetTestDims(expected_output_dims)); - std::vector output_data(expected_num_outputs); - test->BuildAndRun( - {{"input", - /*input_data=*/std::vector(num_inputs, CType(2))}}, - "my_binary", &output_data); + const DataVec input_data{ + {"input", ConstructTensor(num_inputs, CType(2))}}; + DataVec output_data{ + {"my_binary", ConstructTensor(expected_num_outputs)}}; + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, ElementsAreArray(std::vector( - expected_num_outputs, CType(3)))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(std::vector(expected_num_outputs, CType(3)))); } else if (node_def.op() == "Minimum") { - EXPECT_THAT(output_data, ElementsAreArray(std::vector( - expected_num_outputs, CType(1)))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(std::vector(expected_num_outputs, CType(1)))); } else { ASSERT_TRUE(false); } @@ -1712,32 +1894,33 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); - std::vector output_data(4); + const DataVec input_data{ + {"input1", test::AsTensor({CType(3), CType(6)})}, + {"input2", test::AsTensor({CType(2), CType(3)})}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. 
- test->BuildAndRun( - {{"input1", {CType(3), CType(6)}}, {"input2", {CType(2), CType(3)}}}, - "my_binary", &output_data); + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(5), CType(8), CType(6), CType(9))); } else if (node_def.op() == "Sub") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1), CType(4), CType(0), CType(3))); } else if (node_def.op() == "Mul") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(6), CType(12), CType(9), CType(18))); } else if (node_def.op() == "Div") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); } else if (node_def.op() == "RealDiv") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); } else if (node_def.op() == "Minimum") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(2), CType(2), CType(3), CType(3))); } else if (node_def.op() == "Maximum") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(3), CType(6), CType(3), CType(6))); } else { ASSERT_TRUE(false); @@ -1751,7 +1934,9 @@ TEST_F(OpConverterTest, ConvertBinary) { NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}); AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Binary ops require two inputs, at my_add"); + StrCat("Add got ", std::to_string(num_inputs), + " inputs but expected 2, at my_add") + .c_str()); } { // Both inputs are weights. 
@@ -1821,14 +2006,18 @@ TEST_F(OpConverterTest, ConvertBinary) { } TEST_F(OpConverterTest, ConvertQuantize) { - for (const string& op : - {"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars", - "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"}) { + const std::pair op_with_num_inputs[4] = { + {"FakeQuantWithMinMaxArgs", 1}, + {"FakeQuantWithMinMaxVars", 3}, + {"QuantizeAndDequantizeV2", 3}, + {"QuantizeAndDequantizeV3", 4}}; + for (const auto& pair : op_with_num_inputs) { // Input list is empty, should fail. - NodeDef node_def = MakeNodeDef("my_quantize", op, {}); + NodeDef node_def = MakeNodeDef("my_quantize", pair.first, {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - StrCat("Invalid number of inputs for ", op, ", at my_quantize") + StrCat(pair.first, " got 0 inputs but expected ", + std::to_string(pair.second), ", at my_quantize") .c_str()); } { @@ -1915,9 +2104,9 @@ TEST_F(OpConverterTest, ConvertQuantize) { AddTestTensor("weights_min", {1}); AddTestTensor("weights_max", {1}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Min and max inputs for QuantizeAndDequantizeV2 must be weights not " - "tensors, at my_quantize"); + node_def, error::UNIMPLEMENTED, + "The input \"input_min\" for QuantizeAndDequantizeV2 must be a constant" + ", at my_quantize"); } { // QuantizeAndDequantizeV3 ranges set via inputs, ok. @@ -1944,46 +2133,6 @@ TEST_F(OpConverterTest, ConvertQuantize) { } } -TEST_F(OpConverterTest, ConvertRelu6) { - { - // Input list is empty, should fail. - NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Invalid number of inputs for Relu6, at my_relu6"); - } - - // Get the NodeDef for Relu6. 
- Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto relu6 = ops::Relu6(s.WithOpName("my_relu6"), input); - const NodeDef node_def = relu6.operation.node()->def(); - { - // Input is weights, should fail. - Reset(); - AddTestWeights("input", {1}, {1.0f}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "Relu6 is only implemented for tensors, not weights, at my_relu6"); - } - { - // Clip tensor values and set quantization ranges, ok. - Reset(); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_relu6", &output)); - EXPECT_TRUE(output.is_tensor()); - auto ranges = quantization_ranges(); - EXPECT_EQ(ranges[output.tensor()], 6.0f); - - std::vector output_data(6); - BuildAndRun({{"input", {-100, -1, 0, 3, 5, 9}}}, "my_relu6", - &output_data); - EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6)); - } -} - template void TestConvertSquare(OpConverterTest* test) { test->Reset(); @@ -2002,24 +2151,26 @@ void TestConvertSquare(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions()); const int num_inputs = 20; - std::vector input_data(num_inputs); - std::vector expected_output_data(num_inputs); + std::vector inputs(num_inputs); + std::vector expected_outputs(num_inputs); for (int i = 0; i < 20; i++) { const CType value = CType(i - 9); - input_data[i] = value; - expected_output_data[i] = value * value; + inputs[i] = value; + expected_outputs[i] = value * value; } - std::vector output_data(num_inputs); - test->BuildAndRun({{"input", input_data}}, "my_square", &output_data); - ExpectArrayNear(expected_output_data, output_data); + const DataVec input_data{{"input", test::AsTensor(inputs)}}; + DataVec output_data{{"my_square", ConstructTensor(num_inputs)}}; + test->BuildAndRun(input_data, &output_data); + ExpectArrayNear(expected_outputs, GetSpanForData(output_data[0])); } 
TEST_F(OpConverterTest, ConvertSquare) { { // Input list is empty, should fail. NodeDef node_def = MakeNodeDef("my_square", "Square", {}); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Square expects one input, at my_square"); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Square got 0 inputs but expected 1, at my_square"); } { // Input is weights, should fail. @@ -2031,7 +2182,7 @@ TEST_F(OpConverterTest, ConvertSquare) { AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Square is only implemented for tensors, at my_square"); + "The input \"x\" for Square must be a tensor, at my_square"); } // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't @@ -2047,7 +2198,7 @@ TEST_F(OpConverterTest, ConvertActivation) { // Input list is empty, should fail. NodeDef node_def = MakeNodeDef("my_act", "Relu", {}); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Relu expects one input, at my_act"); + "Relu got 0 inputs but expected 1, at my_act"); } { // Input is weights, should fail. @@ -2059,16 +2210,26 @@ TEST_F(OpConverterTest, ConvertActivation) { AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Relu is only implemented for tensors, at my_act"); + "The input \"input\" for Relu must be a tensor, at my_act"); } + constexpr float kAlpha = 0.2f; + // Get nodedef for activation layer. 
auto get_act_nodedef = [](string op_name) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - if (op_name == "Relu") { + if (op_name == "LeakyRelu") { + // LeakyRelu does not have a C++ API + NodeDef node_def = MakeNodeDef("my_act", "LeakyRelu", {"input"}); + (*node_def.mutable_attr())["alpha"].set_f(kAlpha); + return node_def; + } else if (op_name == "Relu") { auto act = ops::Relu(s.WithOpName("my_act"), input); return act.operation.node()->def(); + } else if (op_name == "Relu6") { + auto act = ops::Relu6(s.WithOpName("my_act"), input); + return act.operation.node()->def(); } else if (op_name == "Sigmoid") { auto act = ops::Sigmoid(s.WithOpName("my_act"), input); return act.operation.node()->def(); @@ -2081,8 +2242,12 @@ TEST_F(OpConverterTest, ConvertActivation) { }; // Get expected output for activation layer. auto get_act_output = [](string op_name, float input) -> float { - if (op_name == "Relu") { + if (op_name == "LeakyRelu") { + return (input > 0.0f) ? input : input * kAlpha; + } else if (op_name == "Relu") { return (input > 0.0f) ? input : 0.0f; + } else if (op_name == "Relu6") { + return std::min(std::max(input, 0.0f), 6.0f); } else if (op_name == "Sigmoid") { return 1.0f / (1.0f + std::exp(-input)); } else if (op_name == "Tanh") { @@ -2093,7 +2258,8 @@ TEST_F(OpConverterTest, ConvertActivation) { }; // Ok. - for (string op_name : {"Relu", "Sigmoid", "Tanh"}) { + for (const string& op_name : + {"LeakyRelu", "Relu", "Relu6", "Sigmoid", "Tanh"}) { Reset(); NodeDef node_def = get_act_nodedef(op_name); AddTestTensor("input", {1, 2, 3}); @@ -2102,13 +2268,20 @@ TEST_F(OpConverterTest, ConvertActivation) { TF_EXPECT_OK(GetTensorOrWeights("my_act", &output)); EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + if (op_name == "Relu6") { + // Relu6 should set quantization range automatically. 
+ auto ranges = quantization_ranges(); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + } - const std::vector input_data = {-100, -2, -1, 0, 1, 100}; - std::vector output_data(6); - BuildAndRun({{"input", input_data}}, "my_act", &output_data); - for (int i = 0; i < input_data.size(); i++) { - const float expected_output = get_act_output(op_name, input_data[i]); - EXPECT_FLOAT_EQ(output_data[i], expected_output); + const std::vector input = {-100, -2, -1, 0, 1, 100}; + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_act", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + for (int i = 0; i < input.size(); i++) { + const float expected_output = get_act_output(op_name, input[i]); + EXPECT_FLOAT_EQ(GetSpanForData(output_data[0])[i], + expected_output); } } } @@ -2119,7 +2292,7 @@ TEST_F(OpConverterTest, ConvertExpandDims) { NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Two inputs expected for ExpandDims, at my_expanddims"); + "ExpandDims got 0 inputs but expected 2, at my_expanddims"); } // Get the NodeDef for ExpandDims. @@ -2129,24 +2302,23 @@ TEST_F(OpConverterTest, ConvertExpandDims) { auto expanddims = ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); const NodeDef& node_def = expanddims.operation.node()->def(); - { // Input is weights, should fail. Reset(); AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); AddTestWeights("weights", {1}, {1}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "ExpandDims expects tensor for input, at my_expanddims"); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"input\" for ExpandDims must be a " + "tensor, at my_expanddims"); } { // Axis is a tensor, should fail. 
Reset(); AddTestTensor("input", {1, 2, 3}); AddTestTensor("weights", {3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "ExpandDims expects weights for axis, at my_expanddims"); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"axis\" for ExpandDims must be a " + "constant, at my_expanddims"); } { // Add dim at batch dimension, should fail. @@ -2193,11 +2365,6 @@ TEST_F(OpConverterTest, ConvertExpandDims) { } struct TestParams { - TestParams(const std::vector& input_dims, int axis, - const std::vector& expected_output_dims) - : input_dims(input_dims), - axis(axis), - expected_output_dims(expected_output_dims) {} std::vector input_dims; int axis; std::vector expected_output_dims; @@ -2222,10 +2389,12 @@ TEST_F(OpConverterTest, ConvertExpandDims) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_expanddims", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_expanddims", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -2233,8 +2402,9 @@ TEST_F(OpConverterTest, ConvertSqueeze) { { // Input list is empty, should fail. NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {}); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "One input expected for Squeeze, at my_squeeze"); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Squeeze got 0 inputs but expected 1, at my_squeeze"); } { // No attrs, should fail. 
@@ -2254,7 +2424,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); ops::Squeeze::Attrs squeeze_attrs; - squeeze_attrs.axis_ = gtl::ArraySlice(axis); + squeeze_attrs.axis_ = gtl::ArraySlice(axis); // non-absl ok auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); return squeeze.operation.node()->def(); @@ -2267,7 +2437,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) { AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Squeeze expects tensor for input, at my_squeeze"); + "The input \"input\" for Squeeze must be a tensor, at my_squeeze"); } { // Squeeze batch dim, should fail. @@ -2307,11 +2477,6 @@ TEST_F(OpConverterTest, ConvertSqueeze) { } struct TestParams { - TestParams(const std::vector& input_dims, const std::vector& axis, - const std::vector& expected_output_dims) - : input_dims(input_dims), - axis(axis), - expected_output_dims(expected_output_dims) {} std::vector input_dims; std::vector axis; std::vector expected_output_dims; @@ -2342,10 +2507,1117 @@ TEST_F(OpConverterTest, ConvertSqueeze) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_squeeze", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_squeeze", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertStridedSlice) { + { + // Input list is empty, should fail. 
+ NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice got 0 inputs but expected 4, at my_strided_slice"); + } + + // Get nodedef for StridedSlice layer. + auto get_strided_slice_nodedef = + [](int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, + int new_axis_mask = 0, int shrink_axis_mask = 0) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32); + auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32); + ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs() + .BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask) + .NewAxisMask(new_axis_mask) + .ShrinkAxisMask(shrink_axis_mask); + auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"), + input, begin, end, strides, attrs); + return strided_slice.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"input\" for StridedSlice must be a " + "tensor, at my_strided_slice"); + } + { + // Begin, end, strides are tensors, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("begin", {4}); + AddTestTensor("end", {4}); + AddTestTensor("strides", {4}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"begin\" for StridedSlice must be a constant, at " + "my_strided_slice"); + } + { + // Non-zero ellipsis_mask, should fail. 
+ Reset(); + NodeDef node_def = get_strided_slice_nodedef( + /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2, + /*new_axis_mask=*/0, /*shrink_axis_mask=*/0); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ellipsis_mask is not supported for StridedSlice, at " + "my_strided_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_strided_slice"); + } + { + // Dynamic batch size without end_mask, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_strided_slice"); + } + { + // Dynamic batch size but using end_mask, ok. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0, + /*end_mask=*/1); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 2}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion(node_def); + } +// TRT 5.1+ supports strides +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + { + // Negative strides, should fail. 
+ Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, -1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Negative or zero stride values are not " + "supported for StridedSlice, at " + "my_strided_slice"); + } +#else + { + // Stride is not 1, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, 1, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Strides other than 1 are not supported with " + "this version of TRT, at my_strided_slice"); + } +#endif + { + // Size of sliced dim is negative, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 2, 0}); + AddTestWeights("end", {4}, {1, 1, 0, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "\"size\" cannot be negative or zero for " + "StridedSlice, at my_strided_slice"); + } + + struct TestParams { + std::vector input_dims; + std::vector begin; + std::vector end; + std::vector strides; + int begin_mask; + int end_mask; + std::vector expected_output_dims; + std::vector expected_output; + }; + + auto get_mask = [](const std::vector& mask) { + int result = 0; + for (int i = 0; i < mask.size(); i++) { + if (mask[i]) result += (1 << i); + } + return result; + }; + + // Same input is used for all tests. + const std::vector ok_input = {1, 2, 3, 4, 5, 6}; + +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + const int kStridedSliceOKCases = 23; +#else + const int kStridedSliceOKCases = 19; +#endif + // Ok. 
+ TestParams ok_params[kStridedSliceOKCases] = { + // 2D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 0, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), + /*expected_output_dims=*/{1, 1, 2}, /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{2, 1, 3}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{2, 1, 3}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. 
+ TestParams{/*input_dims=*/{2, 3}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0}), + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1}), + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{5, 6}}, + // 1D Crop. + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 0}), /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 3}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. 
+ TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, /*strides=*/{1, 1}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0}), + /*expected_output_dims=*/{1, 3}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{5, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + // Clamp out of bounds begin and end. 
+ TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, -9999, -9}, + /*end=*/{0, 1, 1000, 4}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}}, +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + // Strides + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 5}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 6}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 1}, /*end=*/{0, 6}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{2, 4, 6}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 2}, /*end=*/{0, 6}, /*strides=*/{1, 3}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{2}, + /*expected_output=*/{3, 6}}, +#endif + }; + + for (int i = 0; i < kStridedSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask, + ok_params[i].end_mask); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("end", {static_cast(ok_params[i].end.size())}, + ok_params[i].end); + AddTestWeights("strides", + {static_cast(ok_params[i].strides.size())}, + ok_params[i].strides); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + 
output.tensor()->getDimensions()); + + const DataVec input_data{{"input", test::AsTensor(ok_input)}}; + DataVec output_data{ + {"my_strided_slice", + ConstructTensor(ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertSlice) { + // Get nodedef for Slice layer. + auto get_slice_nodedef = []() -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32); + auto slice = ops::Slice(s.WithOpName("my_slice"), input, begin, size); + return slice.operation.node()->def(); + }; + + { + // Begin is below bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, -1, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" for dimension 2 in Slice is out of range, at my_slice"); + } + { + // Begin is above bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 3, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" for dimension 2 in Slice is out of range, at my_slice"); + } + { + // Size is below bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 2, -2}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" + \"size\" for dimension 3 in Slice is out of range, at " + "my_slice"); + } + { + // Size is above bounds, should fail. 
+ Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 3, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" + \"size\" for dimension 2 in Slice is out of range, at " + "my_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_slice"); + } + { + // Dynamic batch size with size[0] not -1, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_slice"); + } + { + // Dynamic batch size but using size[0] of -1, ok. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {-1, 1, 2, 2}); + RunValidationAndConversion(node_def); + } + + struct TestParams { + std::vector input_dims; + std::vector begin; + std::vector size; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. 
+ const int kSliceOKCases = 5; + TestParams ok_params[kSliceOKCases] = { + TestParams{{1, 2, 3}, + {0, 0, 0, 0}, + {-1, -1, -1, -1}, + {1, 2, 3}, + {1, 2, 3, 4, 5, 6}}, + TestParams{ + {1, 2, 3}, {0, 0, 0, 0}, {1, 1, 2, 3}, {1, 2, 3}, {1, 2, 3, 4, 5, 6}}, + TestParams{ + {1, 2, 3}, {0, 0, 0, 0}, {1, -1, 2, 2}, {1, 2, 2}, {1, 2, 4, 5}}, + TestParams{{6}, {0, 1}, {1, 5}, {5}, {2, 3, 4, 5, 6}}, + TestParams{{6}, {0, 1}, {-1, 3}, {3}, {2, 3, 4}}, + }; + + for (int i = 0; i < kSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("size", {static_cast(ok_params[i].size.size())}, + ok_params[i].size); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_slice", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_slice", ConstructTensor( + ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertConv2D) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_conv2d", "Conv2D", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Conv2D got 0 inputs but expected 2, at my_conv2d"); + } + + // Get nodedef for Conv2D layer. 
+ auto get_conv2d_nodedef = + [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", + string data_format = "NCHW", std::vector dilations = {1, 1, 1, 1}, + bool is_conv2d_backprop_input = false) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + if (is_conv2d_backprop_input) { + auto input_sizes = + ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32); + ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs() + .DataFormat(data_format) + .Dilations(dilations); + auto conv2d = + ops::Conv2DBackpropInput(s.WithOpName("my_conv2d"), input_sizes, + filter, input, strides, padding, attrs); + return conv2d.operation.node()->def(); + } else { + ops::Conv2D::Attrs attrs = + ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); + auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, + strides, padding, attrs); + return conv2d.operation.node()->def(); + } + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"input\" for Conv2D must be a tensor, at my_conv2d"); + } + { + // Filter is tensor, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3, 3, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"filter\" for Conv2D must be a constant, at my_conv2d"); + } + { + // Filter is not 4D, should fail. 
+ Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Conv2D expects kernel of dimension 4, at my_conv2d"); + } + { + // Dilations is not 4D, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution dilations field must specify 4 dimensions, at my_conv2d"); + } + { + // Dilation value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv2d"); + } + { + // Dilation value is not 1 for channel (NHWC), should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2}); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv2d"); + } + { + // Dilation + Conv2DBackpropInput, should fail. 
+ Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 2, 1}, true); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + AddTestWeights("input_sizes", {4}, {1, 2, 3, 1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation with Conv2DBackpropInput " + "(conv2d_transpose) is not supported, " + "at my_conv2d"); + } + { + // Strides is not 4D, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution strides field must specify 4 dimensions, at my_conv2d"); + } + { + // Stride value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Stride must be 1 for batch and channel dimensions, at my_conv2d"); + } + + struct TestParams { + std::vector input_dims; + std::vector input; + std::vector filter_dims; + std::vector filter; + std::vector strides; + string padding; + string data_format; + std::vector dilations; + bool is_conv2d_backprop_input; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. 
+ const int kConv2DOKCases = 7; + TestParams ok_params[kConv2DOKCases] = { + // Basic + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 1, 0, 1}}, + // SAME padding (Asymmetric) + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 1, -2, 0, 1, -4}}, + // SAME padding (Symmetric) + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 3, 1, 1}, + /*filter=*/{-1, 0, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, -1, 3, 1, -3}}, + // NHWC + TestParams{/*input_dims=*/{2, 3, 1}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NHWC", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{2, 2, 1}, + /*expected_output=*/{1, 1, 0, 1}}, + // Dilated + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 2}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{2, 1}}, + // Strided + TestParams{/*input_dims=*/{1, 2, 4}, + /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, + 
/*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 0, 1, 3}}, + // Transpose Strided + TestParams{/*input_dims=*/{1, 2, 2}, + /*input=*/{0, 1, 2, 3}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/true, + /*expected_output_dims=*/{1, 2, 4}, + /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}}, + }; + + for (int i = 0; i < kConv2DOKCases; i++) { + Reset(); + NodeDef node_def = get_conv2d_nodedef( + ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, + ok_params[i].dilations, ok_params[i].is_conv2d_backprop_input); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", ok_params[i].filter_dims, + ok_params[i].filter); + if (ok_params[i].is_conv2d_backprop_input) { + AddTestWeights( + "input_sizes", + {static_cast(ok_params[i].expected_output.size())}, + ok_params[i].expected_output); + } + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor(ok_params[i].input)}}; + DataVec output_data{ + {"my_conv2d", + ConstructTensor(ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertTopK) { + { + // Input list is empty, should fail. 
+ NodeDef node_def = MakeNodeDef("my_topk", "TopKV2", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Input expects tensor and weights, at my_topk"); + } + + for (const auto dtype : {DT_FLOAT, DT_INT32}) { + // Get the NodeDef for TopKV2. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto topk = ops::TopK(s.WithOpName("my_topk"), input, weights); + const NodeDef& node_def = topk.operation.node()->def(); + { + // K is a tensor, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, + /*trt_dtype=*/TfDataTypeToTrt(dtype)); + AddTestTensor("weights", {2}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Input expects tensor and weights, at my_topk"); + } + { + // Ok. + Reset(); + AddTestTensor("input", {1, 2, 5}); + AddTestWeights("weights", {1}, {2}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights outputs[2]; + TF_EXPECT_OK(GetTensorOrWeights("my_topk", &outputs[0])); + TF_EXPECT_OK(GetTensorOrWeights("my_topk:1", &outputs[1])); + for (auto& output : outputs) { + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 2, 2}, output.tensor()->getDimensions()); + } + + const DataVec input_data{ + {"input", test::AsTensor({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}}; + DataVec output_data{{"my_topk", ConstructTensor(4)}, + {"my_topk:1", ConstructTensor(4)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(6, 5, 7, 1)); + EXPECT_THAT(GetSpanForData(output_data[1]), + ElementsAre(4, 2, 1, 2)); + } + } +} + +template +void TestConvertGather(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + // Get the NodeDef for GatherV2. 
+ Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), dtype); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + + struct TestParams { + std::vector params_dims; + std::vector indices_dims; + std::vector indices; + int axis; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Input is the same {1, 2, 3, 4, 5, 6} for all cases. + const int kGatherOKCases = 5; + TestParams ok_params[kGatherOKCases] = { + // Vector indices (output is rank(params)). + TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}}, + TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}}, + TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}}, + TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}}, + // Higher rank indices (output is rank(params) + rank(indices) - 1). + TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}}, + }; + + // Ok. + for (int i = 0; i < kGatherOKCases; i++) { + test->Reset(); + test->AddTestTensor("params", ok_params[i].params_dims, 1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("indices", ok_params[i].indices_dims, 1, + nvinfer1::DataType::kINT32); + test->AddTestWeights("axis", {1}, {ok_params[i].axis}); + test->RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + // Create input in CType and convert expected output to CType. 
+ std::vector inputs = {CType(1), CType(2), CType(3), + CType(4), CType(5), CType(6)}; + std::vector converted_expected_output( + ok_params[i].expected_output.begin(), + ok_params[i].expected_output.end()); + + const DataVec input_data{ + {"params", test::AsTensor(inputs)}, + {"indices", test::AsTensor(ok_params[i].indices)}}; + DataVec output_data{ + {"my_gather", + ConstructTensor(ok_params[i].expected_output.size())}}; + test->BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(converted_expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertGather) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_gather", "GatherV2", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "GatherV2 got 0 inputs but expected 3, at my_gather"); + } + + // Get the NodeDef for GatherV2. + Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + { + // Axis is a tensor, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestTensor("axis", {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for GatherV2 must be a constant, at my_gather"); + } + { + // Axis is out of bounds, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {4}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in " + "range [-4, 4), at my_gather"); + } + { + // Axis is batch dimension, should fail. 
+ Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {0}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_gather"); + } + + Reset(); + TestConvertGather(this); + TestConvertGather(this); + TestConvertGather(this); +} + +TEST_F(OpConverterTest, ConvertUnary) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_unary", "Neg", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Neg got 0 inputs but expected 1, at my_unary"); + } + { + // Input is weights, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto neg = ops::Neg(s.WithOpName("my_unary"), input); + const NodeDef& node_def = neg.operation.node()->def(); + AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"x\" for Neg must be a tensor, at my_unary"); + } + + // Get nodedef for unary layer. 
+ auto get_unary_nodedef = [](string op_name) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + if (op_name == "Abs") { + auto unary = ops::Abs(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Acos") { + auto unary = ops::Acos(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Acosh") { + auto unary = ops::Acosh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Asin") { + auto unary = ops::Asin(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Asinh") { + auto unary = ops::Asinh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Atan") { + auto unary = ops::Atan(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Atanh") { + auto unary = ops::Atanh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Ceil") { + auto unary = ops::Ceil(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Cos") { + auto unary = ops::Cos(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Cosh") { + auto unary = ops::Cosh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Exp") { + auto unary = ops::Exp(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Floor") { + auto unary = ops::Floor(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Log") { + auto unary = ops::Log(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Neg") { + auto unary = ops::Neg(s.WithOpName("my_unary"), input); + return 
unary.operation.node()->def(); + } else if (op_name == "Reciprocal") { + auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Rsqrt") { + auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sin") { + auto unary = ops::Sin(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sinh") { + auto unary = ops::Sinh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sqrt") { + auto unary = ops::Sqrt(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Tan") { + auto unary = ops::Tan(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } + EXPECT_TRUE(false); + return NodeDef(); + }; + // Get expected output for unary layer. + auto get_unary_output = [](string op_name, float input) -> float { + if (op_name == "Abs") { + return std::abs(input); + } else if (op_name == "Acos") { + return std::acos(input); + } else if (op_name == "Acosh") { + return std::acosh(input); + } else if (op_name == "Asin") { + return std::asin(input); + } else if (op_name == "Asinh") { + return std::asinh(input); + } else if (op_name == "Atan") { + return std::atan(input); + } else if (op_name == "Atanh") { + return std::atanh(input); + } else if (op_name == "Ceil") { + return std::ceil(input); + } else if (op_name == "Cos") { + return std::cos(input); + } else if (op_name == "Cosh") { + return std::cosh(input); + } else if (op_name == "Exp") { + return std::exp(input); + } else if (op_name == "Floor") { + return std::floor(input); + } else if (op_name == "Log") { + return std::log(input); + } else if (op_name == "Neg") { + return -input; + } else if (op_name == "Reciprocal") { + return 1.0 / input; + } else if (op_name == "Rsqrt") { + return 1.0 / std::sqrt(input); + } else if (op_name == "Sin") { 
+ return std::sin(input); + } else if (op_name == "Sinh") { + return std::sinh(input); + } else if (op_name == "Sqrt") { + return std::sqrt(input); + } else if (op_name == "Tan") { + return std::tan(input); + } + EXPECT_TRUE(false); + return 0; + }; + + // Get list of ops to test. + std::vector ops_to_test; + // Add all ops supported by ConvertUnary. + auto* map = UnaryOperationMap(); + ops_to_test.reserve(map->size()); + for (auto& pair : *map) { + ops_to_test.push_back(pair.first); + } + // Add other unary ops to test. + ops_to_test.push_back("Rsqrt"); + // Ok. + for (string op_name : ops_to_test) { + Reset(); + NodeDef node_def = get_unary_nodedef(op_name); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + + const std::vector input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_unary", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + for (int i = 0; i < input.size(); ++i) { + const float expected_output = get_unary_output(op_name, input[i]); + EXPECT_THAT(GetSpanForData(output_data[0])[i], + NanSensitiveFloatNear(expected_output, 0.0001)); + } } } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc similarity index 92% rename from tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc rename to tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index c1688d4db88a270dcd202989f89a677ed10576d9..0eedfcacb4c11c8dc63fcfc13f044586b99b3c76 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -12,9 +12,11 @@ See the License for the specific language 
governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h" -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" @@ -30,9 +32,9 @@ namespace tensorflow { namespace tensorrt { namespace convert { // TODO(sami): Remove VLOG messages once the code matures +using absl::StrAppend; +using absl::StrCat; using tensorflow::str_util::Uppercase; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; tensorflow::Status TRTOptimizationPass::Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { @@ -64,7 +66,7 @@ tensorflow::Status TRTOptimizationPass::Init( max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i(); } if (params.count("precision_mode")) { - TF_RETURN_IF_ERROR(GetPrecisionMode( + TF_RETURN_IF_ERROR(TrtPrecisionModeFromName( Uppercase(params.at("precision_mode").s()), &precision_mode_)); } if (params.count("use_calibration")) { @@ -85,7 +87,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << offset << "type = " << cluster->type(); LOG(INFO) << offset << "num warmup steps = " << cluster->NumWarmupSteps(); const auto dev_names = cluster->GetDeviceNames(); - if (dev_names.size()) { + if (!dev_names.empty()) { LOG(INFO) << offset << " Device names:"; for (const auto s : dev_names) { LOG(INFO) << offset2 << s; @@ -101,7 +103,7 @@ void TRTOptimizationPass::PrintDebugInfo( } const 
auto dev_props = cluster->GetDevices(); - if (dev_props.size()) { + if (!dev_props.empty()) { LOG(INFO) << offset << "Device properties:"; for (auto k : dev_props) { LOG(INFO) << offset2 << k.first; @@ -129,7 +131,7 @@ void TRTOptimizationPass::PrintDebugInfo( } } LOG(INFO) << "item: " << item.id; - if (item.feed.size()) { + if (!item.feed.empty()) { LOG(INFO) << offset << "Feeds :"; for (const auto& f : item.feed) { const auto& shape = f.second.shape(); @@ -138,7 +140,7 @@ void TRTOptimizationPass::PrintDebugInfo( } else { LOG(INFO) << offset << "No Feeds"; } - if (item.fetch.size()) { + if (!item.fetch.empty()) { LOG(INFO) << offset << "Fetches :"; for (const auto& f : item.fetch) { LOG(INFO) << offset2 << f; @@ -147,7 +149,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << offset << "No Fetches"; } - if (item.init_ops.size()) { + if (!item.init_ops.empty()) { LOG(INFO) << offset << "init ops :"; for (const auto& f : item.init_ops) { LOG(INFO) << offset2 << f; @@ -158,7 +160,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << "Save Op = " << item.save_op; LOG(INFO) << "Restore Op = " << item.restore_op; LOG(INFO) << "save_restore_loc_tensor = " << item.save_restore_loc_tensor; - if (item.keep_ops.size()) { + if (!item.keep_ops.empty()) { LOG(INFO) << offset << "keep ops :"; for (const auto& f : item.keep_ops) { LOG(INFO) << offset2 << f; @@ -195,7 +197,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( PrintDebugInfo(cluster, item); } int max_dim = -1; - if (item.feed.size()) { + if (!item.feed.empty()) { for (const auto& f : item.feed) { const auto& shape = f.second.shape(); if (shape.dims() > 0) { @@ -225,9 +227,10 @@ tensorflow::Status TRTOptimizationPass::Optimize( TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); tensorflow::tensorrt::convert::ConversionParams cp; - if (use_calibration_ && precision_mode_ != INT8MODE) { - LOG(ERROR) << "Calibration with FP32 or FP16 is not implemented. 
" - << "Falling back to use_calibration = False."; + if (use_calibration_ && precision_mode_ != TrtPrecisionMode::INT8) { + VLOG(1) << "Calibration with FP32 or FP16 is not implemented. " + << "Falling back to use_calibration = False." + << "Note that the default value of use_calibration is True."; use_calibration_ = false; } @@ -242,7 +245,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( // If the last token is not an integer, it must be part of the name. // Otherwise it is port number. if (tokens.size() > 1 && - !strings::safe_strto32(tokens.back(), &dumm_port)) { + !strings::safe_strto32(tokens.back(), &dumm_port)) { // non-absl ok StrAppend(&s, ":", tokens.back()); } nodes_to_preserve.push_back(s); diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h similarity index 87% rename from tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h rename to tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index 3e8dc0978e43e2e9ba07aaa09f74acfe8e59b9a7..b2aed2a37afb6c01863f5617bad0bafe004eec24 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ #include +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" #include "tensorflow/core/platform/logging.h" @@ -34,7 +35,7 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { TRTOptimizationPass(const string& name = "TRTOptimizationPass") : name_(name), minimum_segment_size_(3), - precision_mode_(0), + precision_mode_(TrtPrecisionMode::FP32), maximum_batch_size_(-1), is_dynamic_op_(false), max_cached_batches_(1), @@ -62,7 +63,7 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { private: const string name_; int minimum_segment_size_; - int precision_mode_; + TrtPrecisionMode precision_mode_; int maximum_batch_size_; bool is_dynamic_op_; std::vector batches_; @@ -77,4 +78,4 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { #endif // GOOGLE_CUDA #endif // GOOGLE_TENSORRT -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc similarity index 73% rename from tensorflow/contrib/tensorrt/convert/utils.cc rename to tensorflow/compiler/tf2tensorrt/convert/utils.cc index e7a1febb8c076891596741fe30721e7acca15a73..0ca3a5a4a58e6a3e29d3d515f496b8cb5e9f7eb0 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -13,7 +13,7 @@ See the 
License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -34,33 +34,32 @@ bool IsGoogleTensorRTEnabled() { #endif } -Status GetPrecisionModeName(const int precision_mode, string* name) { - switch (precision_mode) { - case FP32MODE: +Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name) { + switch (mode) { + case TrtPrecisionMode::FP32: *name = "FP32"; break; - case FP16MODE: + case TrtPrecisionMode::FP16: *name = "FP16"; break; - case INT8MODE: + case TrtPrecisionMode::INT8: *name = "INT8"; break; default: - return tensorflow::errors::OutOfRange("Unknown precision mode"); + return errors::OutOfRange("Unknown precision mode"); } return Status::OK(); } -Status GetPrecisionMode(const string& name, int* precision_mode) { +Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) { if (name == "FP32") { - *precision_mode = FP32MODE; + *mode = TrtPrecisionMode::FP32; } else if (name == "FP16") { - *precision_mode = FP16MODE; + *mode = TrtPrecisionMode::FP16; } else if (name == "INT8") { - *precision_mode = INT8MODE; + *mode = TrtPrecisionMode::INT8; } else { - return tensorflow::errors::InvalidArgument("Invalid precision mode name: ", - name); + return errors::InvalidArgument("Invalid precision mode name: ", name); } return Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h similarity index 72% rename from tensorflow/contrib/tensorrt/convert/utils.h rename to tensorflow/compiler/tf2tensorrt/convert/utils.h index 0592f31462af2b20f3a13fe5119e89c2ba42dd8a..0aa602dda2f3e98095bf72b5810a246c690d6741 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.h 
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ #include @@ -35,16 +35,13 @@ using TrtUniquePtrType = std::unique_ptr>; bool IsGoogleTensorRTEnabled(); -// TODO(aaroey): use an enum instead. -const int FP32MODE = 0; -const int FP16MODE = 1; -const int INT8MODE = 2; +enum class TrtPrecisionMode { FP32, FP16, INT8 }; -Status GetPrecisionModeName(const int precision_mode, string* name); +Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name); -Status GetPrecisionMode(const string& name, int* precision_mode); +Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode); } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..81406b6e301ca350a3e52c97f5fcb575e88c3a90 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/refcount.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class GetSerializedResourceOp : public OpKernel { + public: + explicit GetSerializedResourceOp(OpKernelConstruction* context) + : OpKernel(context) {} + + ~GetSerializedResourceOp() override {} + + void Compute(OpKernelContext* context) override { + // TODO(laigd): it will allocate the tensor on the device and copy the + // serialized string to that tensor, and later sess.run() will copy it back + // to host. We need to optimize this. + const string& container = context->input(0).scalar()(); + const string& resource_name = context->input(1).scalar()(); + + // Get the resource. + SerializableResourceBase* resource = nullptr; + OP_REQUIRES_OK(context, context->resource_manager()->Lookup( + container, resource_name, &resource)); + ::tensorflow::core::ScopedUnref sc(resource); + + // Serialize the resource as output. 
+ string serialized_resource; + OP_REQUIRES_OK(context, resource->SerializeToString(&serialized_resource)); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = serialized_resource; + } +}; + +REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU), + GetSerializedResourceOp); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec038ebda073c8050321d5668b15a2c6faa72a4b --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class GetSerializedResourceOpTest : public OpsTestBase {}; + +TEST_F(GetSerializedResourceOpTest, Basic) { + // Create the GPU device. + std::unique_ptr device( + DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0")); + + // Create the resource. + class MySerializableResource : public SerializableResourceBase { + public: + string DebugString() const override { return ""; } + Status SerializeToString(string* serialized) override { + *serialized = "my_serialized_str"; + return Status::OK(); + } + }; + const string container = "mycontainer"; + const string resource_name = "myresource"; + SerializableResourceBase* resource = new MySerializableResource(); + ResourceMgr* rm = device->resource_manager(); + EXPECT_TRUE(rm->Create(container, resource_name, resource).ok()); + + // Create the op. + SetDevice(DEVICE_GPU, std::move(device)); + TF_ASSERT_OK(NodeDefBuilder("op", "GetSerializedResourceOp") + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + // Execute the op. + AddInputFromArray(TensorShape({}), {container}); + AddInputFromArray(TensorShape({}), {resource_name}); + TF_ASSERT_OK(RunOpKernel()); + + // Verify the result. + // TODO(laigd): OpsTestBase::GetOutput() doesn't work. 
+ Tensor* output = context_->mutable_output(0); + EXPECT_EQ("my_serialized_str", output->scalar()()); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc similarity index 59% rename from tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc rename to tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index bad568644bb1f8d01d4cb0a7c853ec47d6f19e45..f6d387c59cd04aa5c7ccad610290b7b1f1d2b11f 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -12,35 +12,44 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" - #include - -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include 
"tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { static Logger logger; +using absl::StrAppend; +using absl::StrCat; using ::nvinfer1::IRuntime; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; // A helper class to call done() when destructed for asynchronous execution. // Helps simultaneous execution of native and TRT engines. @@ -53,6 +62,83 @@ class AsyncHelper : public tensorflow::core::RefCounted { AsyncOpKernel::DoneCallback done_; }; +// This OP can construct TRTEngine on the fly and if construction of engine +// fails, executes equivalent subgraph as a TensorFlow function. +class TRTEngineOp : public AsyncOpKernel { + public: + explicit TRTEngineOp(OpKernelConstruction* context); + + void ComputeAsync(OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + + private: + // Execute calibration + void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); + + // Construct a function handle for executing native funcdef graph + Status ConstructFunctionHandle(OpKernelContext* ctx); + + // Execute replaced native segment as function Op. 
+ void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + + // Execute the tensorrt engine. Returns whether we need to retry by running + // the native segment. + bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context); + + // Allocate necessary resources for calibration + Status AllocateCalibrationResources(OpKernelContext* ctx, + SerializableResourceBase** cr); + + // Get engine for the input shape + EngineContext* GetEngine(const std::vector& input_shapes, + OpKernelContext* ctx); + + // Return engine batch in cached_engne_batch_sizes_ which is closest to input + // batch. + bool GetCompatibleCachedEngine( + const std::vector& actual_input_shapes, + std::vector* engine_input_shapes); + + std::vector input_nodes_; + std::vector output_nodes_; + + // serialized protobuf segment or trt engine depending on static_engine_ flag. + string serialized_segment_; + + // Name of the function for TF native execution of the segment. + string funcdef_name_; + + // GraphDef representation of the segment. + GraphDef segment_graph_; + + // Engine Precision mode. + TrtPrecisionMode precision_mode_; + + // Whether engine is constructed during the conversion or needs to be + // constructed from protobuf segment. + bool static_engine_; + + // Whether to calibrate INT8 engine. + bool calibration_mode_; + + // Batches of the cached engines + std::vector cached_engine_batches_; + + // Maximum number of cached engines + int max_cached_engines_; + + int64 workspace_size_; + mutex engine_mutex_; + FunctionLibraryRuntime::Handle native_func_; + + // The finalized calibrator for inference. + std::unique_ptr calibrator_; + + // If true, create calibration graph for INT8 mode. Otherwise, we are using + // user-provided quantization ranges. 
+ bool use_calibration_; +}; + #define TYPECASE(dt, X, Y) \ case dt: { \ return (void*)X->flat::Type>().data(); \ @@ -123,20 +209,20 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) context->GetAttr("calibration_data", &calibration_data)); OP_REQUIRES_OK(context, context->GetAttr("segment_funcdef_name", &funcdef_name_)); - OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_)); + OP_REQUIRES_OK(context, + TrtPrecisionModeFromName(precision_string, &precision_mode_)); OP_REQUIRES_OK(context, context->GetAttr("use_calibration", &use_calibration_)); - calibration_mode_ = (use_calibration_ && precision_mode_ == INT8MODE && - calibration_data.size() == 0); - if (calibration_data.size()) { + calibration_mode_ = + (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 && + calibration_data.empty()); + if (!calibration_data.empty()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); } native_func_ = tensorflow::kInvalidHandle; OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", &max_cached_engines_)); - OP_REQUIRES_OK(context, - context->GetAttr("fixed_input_size", &fixed_input_size_)); OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", &cached_engine_batches_)); std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end()); @@ -167,6 +253,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, opts.rendezvous = ctx->rendezvous(); opts.cancellation_manager = ctx->cancellation_manager(); opts.runner = ctx->runner(); + inputs.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); i++) { inputs.push_back(ctx->input(i)); } @@ -175,11 +262,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, lib->Run(opts, native_func_, inputs, outputs, [this, ctx, outputs, helper](const tensorflow::Status& s) { tensorflow::core::ScopedUnref sc(helper); - VLOG(1) << "Native Segment completed"; if (!s.ok()) { + LOG(ERROR) << "Failed to 
execute native segment " << this->name() + << ": " << s; ctx->SetStatus(s); return; } + VLOG(1) << "Native Segment completed"; for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); } @@ -194,19 +283,17 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, VLOG(1) << "Executing TRT calibration: " << name(); helper->Ref(); tensorflow::core::ScopedUnref sc(helper); - // TODO(aaroey): remove the ResourceMgr singleton. - auto trt_rm = TRTResourceManager::instance(); - auto res_mgr = trt_rm->getManager("TRTCalibration"); + auto res_mgr = ctx->resource_manager(); TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->LookupOrCreate( - funcdef_name_, "Calibrator", &calib_res, - {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status { - return this->AllocateCalibrationResources(ctx, cr); - }}); - if (!status.ok()) { - ctx->SetStatus(status); - return; - } + OP_REQUIRES_OK( + ctx, + res_mgr->LookupOrCreate( + "TF_TRT_Calibration", name(), + reinterpret_cast(&calib_res), + {[ctx, this](SerializableResourceBase** cr) -> tensorflow::Status { + return this->AllocateCalibrationResources(ctx, cr); + }})); + tensorflow::core::ScopedUnref calib_sc(calib_res); int num_inputs = ctx->num_inputs(); // Pass input data to calibrator std::unordered_map input_data; @@ -219,7 +306,8 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, return; } // Check the allocated buffer is sufficient for input - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + const auto device_tensor = + calib_res->device_tensors_.at(i).AccessTensor(ctx); CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); input_data.emplace(StrCat(kInputPHName, i), data_address); } @@ -236,32 +324,34 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, ExecuteNativeSegment(ctx, helper); } -int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { - int num_batch = ctx->input(0).shape().dim_size(0); - int smallest_engine = 0; - for (const 
auto i : cached_engine_batches_) { - if (i >= num_batch) { - smallest_engine = i; - break; - } - } - // TODO(sami): Need an LRU here - if (smallest_engine == 0) { - if (max_cached_engines_ > cached_engine_batches_.size()) { - smallest_engine = num_batch; - cached_engine_batches_.push_back(num_batch); - VLOG(1) << "Running with batch size " << num_batch; - } else { - string msg = - StrCat("Engine buffer is full. buffer limit=", max_cached_engines_, - ", current entries="); - for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); - StrAppend(&msg, " requested batch=", num_batch); - LOG(WARNING) << msg; - return -1; +bool TRTEngineOp::GetCompatibleCachedEngine( + const std::vector& actual_input_shapes, + std::vector* engine_input_shapes) { + const int batch_size = actual_input_shapes[0].dim_size(0); + int smallest_batch_size = -1; + // Output shape will always be the same as the input but we will overwrite the + // batch size. + *engine_input_shapes = actual_input_shapes; + for (const int cached_batch_size : cached_engine_batches_) { + // Check if compatible: batch <= cached batch. + // + // TODO(laigd): here it only compare the first dim a.k.a the batch size, + // we'll need to to support non-batch dimensions as well. This will be done + // as part of the offline conversion implementation. 
+ if (batch_size <= cached_batch_size) { + // First case: first compatible engine found + // Second case: smaller batch size engine found + if ((smallest_batch_size == -1) || + (cached_batch_size < smallest_batch_size)) { + smallest_batch_size = cached_batch_size; + // Overwrite batch size for output + for (int i = 0; i < engine_input_shapes->size(); i++) { + (*engine_input_shapes)[i].set_dim(0, smallest_batch_size); + } + } } } - return smallest_engine; + return (smallest_batch_size != -1); } void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, @@ -272,25 +362,21 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, ExecuteCalibration(ctx, helper); return; } - const int smallest_engine = GetEngineBatch(ctx); - if (smallest_engine < 0) { - LOG(WARNING) << "Failed to get engine batch, running native segment for " - << name(); - ExecuteNativeSegment(ctx, helper); - return; + // Get shapes of inputs to engine. + std::vector input_shapes; + input_shapes.reserve(ctx->num_inputs()); + for (int i = 0; i < ctx->num_inputs(); ++i) { + input_shapes.push_back(ctx->input(i).shape()); } - - const int num_batch = ctx->input(0).shape().dim_size(0); - auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); - auto& trt_engine_ptr = engine_ctx_pair.first; - if (!trt_engine_ptr) { - LOG(WARNING) << "Engine retrieval for batch size " << num_batch + EngineContext* engine_context = GetEngine(input_shapes, ctx); + if (!engine_context->cuda_engine) { + LOG(WARNING) << "Engine retrieval for input shapes: " + << TensorShapeUtils::ShapeListString(input_shapes) << " failed. 
Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } - const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(), - engine_ctx_pair.second.get()); + const bool retry = ExecuteTrtEngine(ctx, engine_context); if (retry) { LOG(WARNING) << "Failed to execute engine, " << "retrying with native segment for " << name(); @@ -299,18 +385,19 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } } -bool TRTEngineOp::ExecuteTrtEngine( - OpKernelContext* ctx, const int num_batch, - nvinfer1::ICudaEngine* trt_engine_ptr, - nvinfer1::IExecutionContext* trt_execution_context_ptr) { +bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, + EngineContext* engine_context) { VLOG(1) << "Executing TRT engine: " << name(); + auto& cuda_engine = engine_context->cuda_engine; const bool kRetry = true; + // All inputs must have the same batch size, so just get it from the first + // input. + const int num_batch = ctx->input(0).shape().dim_size(0); const int num_binding = ctx->num_inputs() + ctx->num_outputs(); std::vector buffers(num_binding); for (int i = 0; i < ctx->num_inputs(); i++) { const string input_name = StrCat(kInputPHName, i); - const int binding_index = - trt_engine_ptr->getBindingIndex(input_name.c_str()); + const int binding_index = cuda_engine->getBindingIndex(input_name.c_str()); if (binding_index == -1) { LOG(ERROR) << "Input node not found, at " << input_name; return kRetry; @@ -323,10 +410,11 @@ bool TRTEngineOp::ExecuteTrtEngine( << " vs " << input_shape.dim_size(0); return kRetry; } - auto dtype = trt_engine_ptr->getBindingDataType(binding_index); + auto dtype = cuda_engine->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: - buffers[binding_index] = (void*)(input_tensor.flat().data()); + buffers[binding_index] = + const_cast(input_tensor.flat().data()); break; case nvinfer1::DataType::kHALF: LOG(ERROR) << "FP16 inputs are not supported yet!"; @@ -335,10 +423,11 @@ bool 
TRTEngineOp::ExecuteTrtEngine( LOG(ERROR) << "INT8 inputs are not supported yet!"; return kRetry; case nvinfer1::DataType::kINT32: - buffers[binding_index] = (void*)(input_tensor.flat().data()); + buffers[binding_index] = + const_cast(input_tensor.flat().data()); break; default: - LOG(ERROR) << "Unknown TRT data type: " << int(dtype); + LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); return kRetry; } } @@ -346,13 +435,12 @@ bool TRTEngineOp::ExecuteTrtEngine( for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor const string output_name = StrCat(kOutputPHName, i); - const int binding_index = - trt_engine_ptr->getBindingIndex(output_name.c_str()); + const int binding_index = cuda_engine->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; if (binding_index != -1) { - auto dims = trt_engine_ptr->getBindingDimensions(binding_index); + auto dims = cuda_engine->getBindingDimensions(binding_index); std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; @@ -374,11 +462,11 @@ bool TRTEngineOp::ExecuteTrtEngine( // TODO(aaroey): ideally we should retry, fix this. 
return !kRetry; } - auto dtype = trt_engine_ptr->getBindingDataType(binding_index); + auto dtype = cuda_engine->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = - reinterpret_cast(output_tensor->flat().data()); + const_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: LOG(WARNING) << "half size is not supported yet!"; @@ -388,7 +476,7 @@ bool TRTEngineOp::ExecuteTrtEngine( return kRetry; case nvinfer1::DataType::kINT32: buffers[binding_index] = - reinterpret_cast(output_tensor->flat().data()); + const_cast(output_tensor->flat().data()); break; default: LOG(WARNING) << "Unknown TRT data type: " << static_cast(dtype); @@ -402,9 +490,12 @@ bool TRTEngineOp::ExecuteTrtEngine( ->implementation() ->GpuStreamMemberHack())); + // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex + // for it. + tensorflow::mutex_lock lock(engine_context->mu); // TODO(jie): trt enqueue does not return error - auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, - nullptr); + auto ret = engine_context->execution_context->enqueue(num_batch, &buffers[0], + *stream, nullptr); if (!ret) { LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); return kRetry; @@ -414,50 +505,45 @@ bool TRTEngineOp::ExecuteTrtEngine( return !kRetry; } -TRTEngineOp::~TRTEngineOp() { - // We need to manually destroy the engine and execution context before - // the allocator is destructed. - for (auto& eng : engine_map_) { - eng.second.first.reset(); - eng.second.second.reset(); +EngineContext* TRTEngineOp::GetEngine( + const std::vector& input_shapes, OpKernelContext* ctx) { + static EngineContext empty_context; + tensorflow::mutex_lock lock(engine_mutex_); + // TODO(tmorris): using first input to get batch size - is this reliable? 
+ const int batch_size = input_shapes[0].dim_size(0); + + // Get engine cache + TRTEngineCacheResource* cache_res = nullptr; + auto status = ctx->resource_manager()->LookupOrCreate( + "TRTEngineCache", funcdef_name_, &cache_res, + {[this, ctx](TRTEngineCacheResource** cr) -> tensorflow::Status { + *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_); + return Status::OK(); + }}); + if (!status.ok()) { + ctx->SetStatus(status); + return &empty_context; } - allocator_.reset(); -} - -nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { - if (allocator_) return allocator_.get(); - auto device = ctx->device(); - auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); - if (!alloc) { - LOG(ERROR) << "Can't find device allocator for gpu device " - << device->name(); - return nullptr; + tensorflow::core::ScopedUnref sc(cache_res); + auto& cache = cache_res->cache_; + auto allocator = cache_res->allocator_.get(); + if (allocator == nullptr) { + return &empty_context; } - allocator_.reset(new TRTDeviceAllocator(alloc)); - return allocator_.get(); -} - -TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, - OpKernelContext* ctx) { - static EngineCtxPair null_pair = { - TrtUniquePtrType(nullptr), - TrtUniquePtrType(nullptr)}; - // TODO(sami): This method needs to be re-written to use resource manager and - // with LRU mechanism option. - tensorflow::mutex_lock lock(engine_mutex_); + // Handle the static engine case. For static engines, the cache will have a + // single element containing the only engine. if (static_engine_) { - if (engine_map_.size()) { - if (engine_map_.begin()->first >= batch_size) { - return engine_map_.begin()->second; + if (cache.size()) { + // Batch size of engine must be >= the input batch size + // TODO(tmorris): use match compatible function? 
+ if (cache.begin()->first[0].dim_size(0) >= batch_size) { + return cache.begin()->second.get(); } - return null_pair; + return &empty_context; } + TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); - auto allocator = GetAllocator(ctx); - if (allocator == nullptr) { - return null_pair; - } infer->setGpuAllocator(allocator); TrtUniquePtrType static_engine( infer->deserializeCudaEngine(serialized_segment_.c_str(), @@ -465,62 +551,87 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, PluginFactoryTensorRT::GetInstance())); auto raw_static_engine = static_engine.get(); const auto max_batch_size = raw_static_engine->getMaxBatchSize(); - engine_map_[max_batch_size] = { - std::move(static_engine), - TrtUniquePtrType( - raw_static_engine->createExecutionContext())}; + // Static engine will have max_batch_size for batch size so that all inputs + // will map to this single engine. + std::vector engine_input_shapes(input_shapes); + for (int i = 0; i < engine_input_shapes.size(); i++) { + // TODO(tmorris): will all inputs have batch size as first dimension?? + engine_input_shapes[i].set_dim(0, max_batch_size); + } + // TODO(laigd): here we assume engine_input_shapes matches the actual input + // shapes of the engine, we should verify that. + cache.emplace(engine_input_shapes, + absl::make_unique( + std::move(static_engine), + TrtUniquePtrType( + raw_static_engine->createExecutionContext()))); // Runtime is safe to delete after engine creation serialized_segment_.clear(); if (max_batch_size < batch_size) { - return null_pair; + return &empty_context; } - return engine_map_.at(max_batch_size); + return cache.at(engine_input_shapes).get(); } // static_engine_ // Handle the dynamic engine case. 
- auto engine_it = engine_map_.find(batch_size); - if (engine_it == engine_map_.end() && - engine_map_.size() < (size_t)max_cached_engines_) { - nvinfer1::IGpuAllocator* allocator = nullptr; - allocator = GetAllocator(ctx); - if (allocator == nullptr) { - return null_pair; - } - std::vector shapes; - for (int i = 0; i < ctx->num_inputs(); ++i) { - shapes.emplace_back(ctx->input(i).shape()); + // See if there is a compatible engine cached. The batch size should be <= the + // cached batch size. + std::vector engine_input_shapes; + const bool matched_successfully = + GetCompatibleCachedEngine(input_shapes, &engine_input_shapes); + // If matched, use that engine. Otherwise, we will look in cache for that + // exact shape and possibly create a new engine if it is not in cache. + if (!matched_successfully) { + engine_input_shapes = input_shapes; + if (!cached_engine_batches_.empty()) { + // If user has explicitly defined cached_engine_batches, we should + // warn them that their input was non-compatible (batch size too high) + LOG(WARNING) << "No compatible cached engine was found for batch size: " + << batch_size << ". A new engine will be created."; + cached_engine_batches_.push_back(batch_size); } + } + + if (!cache.count(engine_input_shapes)) { TrtUniquePtrType engine; bool convert_successfully = false; LOG(INFO) << "Building a new TensorRT engine for " << name() - << " with batch size " << batch_size; + << " input shapes: " + << TensorShapeUtils::ShapeListString(engine_input_shapes); + + // Convert to partial shapes + std::vector partial_shapes(engine_input_shapes.begin(), + engine_input_shapes.end()); + // Up to this point, calibrator_ can never be empty, since otherwise it // means calibration_mode_ is true and this path won't get executed. 
auto status = convert::ConvertGraphDefToEngine( - segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, - &logger, allocator, calibrator_.get(), &engine, use_calibration_, - &convert_successfully); + segment_graph_, precision_mode_, batch_size, workspace_size_, + partial_shapes, &logger, allocator, calibrator_.get(), &engine, + use_calibration_, &convert_successfully); if (!status.ok()) { if (convert_successfully) { // This means it fail to build the engine even when the network is built // successfully, probably due to internal issues. In this case we don't // retry in the future. - engine_map_[batch_size] = {nullptr, nullptr}; + cache.emplace(engine_input_shapes, absl::make_unique()); } LOG(WARNING) << "Engine creation for batch size " << batch_size << " failed " << status; - return null_pair; + return &empty_context; } VLOG(1) << "Conversion is done"; TrtUniquePtrType exec_context( engine->createExecutionContext()); - engine_map_[batch_size] = {std::move(engine), std::move(exec_context)}; + cache.emplace(engine_input_shapes, + absl::make_unique(std::move(engine), + std::move(exec_context))); } - return engine_map_.at(batch_size); + return cache.at(engine_input_shapes).get(); } tensorflow::Status TRTEngineOp::AllocateCalibrationResources( - OpKernelContext* ctx, TRTCalibrationResource** cr) { + OpKernelContext* ctx, SerializableResourceBase** cr) { auto cres = new TRTCalibrationResource(); *cr = cres; // Get the allocator. 
@@ -536,7 +647,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( const int batch_size = ctx->input(0).dim_size(0); const int num_inputs = ctx->num_inputs(); std::vector shapes; - dev_tensors_.resize(num_inputs); + cres->device_tensors_.resize(num_inputs); VLOG(1) << " Constructing calibrator"; for (int i = 0; i < num_inputs; i++) { // allocate workspace on device for inputs @@ -544,19 +655,19 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( shapes.emplace_back(t.shape()); Tensor* device_tensor; TF_RETURN_IF_ERROR(ctx->allocate_persistent( - t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor)); + t.dtype(), t.shape(), &cres->device_tensors_.at(i), &device_tensor)); CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); void* device_address = GetTensorAddress(device_tensor); if (device_address == nullptr) { return tensorflow::errors::InvalidArgument( "Unsupported data type encountered in input ", i); } - device_buffers_.emplace( + cres->device_buffers_.emplace( StrCat(kInputPHName, i), std::pair(device_address, device_tensor->TotalBytes())); } cres->calibrator_.reset( - new TRTInt8Calibrator(device_buffers_, batch_size, name())); + new TRTInt8Calibrator(cres->device_buffers_, batch_size, name())); const string label(name()); auto segment_graph = &segment_graph_; const int platform_gpu_id = @@ -585,9 +696,10 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( // TODO(aaroey): maybe setting the max batch size using the python // calibration wrapper class. 
auto s = convert::ConvertGraphDefToEngine( - *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(), - workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), - cres->calibrator_.get(), &cres->engine_, + *segment_graph, TrtPrecisionMode::INT8, + cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes, + &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(), + &cres->engine_, /*use_calibration=*/true, /*convert_successfully=*/nullptr); if (!s.ok()) { diff --git a/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..59da73f5efc8eedc20c35cf35cb1eae6cda136c9 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +REGISTER_OP("GetSerializedResourceOp") + .Input("container: string") + .Input("resource_name: string") + .Output("serialized_resource: string") + .SetShapeFn(shape_inference::ScalarShape) + .SetIsStateful() + .Doc(R"doc( +Gets a resource from a container managed by the resource manager and returns +its serialized representation. +)doc"); + +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc similarity index 80% rename from tensorflow/contrib/tensorrt/ops/trt_engine_op.cc rename to tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index 92405906eb76b043bc08b68e25e16ab40197dddf..b84d2fe0b8cef3475f2a7d0f5383d5e11cde099a 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -28,16 +28,22 @@ namespace shape_inference { extern Status TRTEngineOpShapeInference(InferenceContext* c); } +// NOTE: please try NOT to add/modify/remove attributes or inputs/outputs to the +// list below, this will break backward compatibility! +// +// TODO(laigd): consider making this op stateful. The only problem is it uses TF +// function which has to be stateless, but we can use function library as the +// key to cache the instantiated functions for different executor subgraphs. 
REGISTER_OP("TRTEngineOp") .Attr("serialized_segment: string") .Attr("input_shapes: list(shape)") .Attr("output_shapes: list(shape)") .Attr("segment_funcdef_name: string") - .Attr("InT: list({int8,float16,float32})") - .Attr("OutT: list({int8,float16,float32})") + .Attr("InT: list({int8,float16,float32,int32})") + .Attr("OutT: list({int8,float16,float32,int32})") .Attr("static_engine: bool = true") .Attr("fixed_input_size: bool = true") - .Attr("cached_engine_batches: list(int) = []") + .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("max_cached_engines_count: int = 1") .Attr("workspace_size_bytes: int") .Attr("precision_mode: {'FP32', 'FP16', 'INT8'}") diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc index 062f86e8bb4dc753925e4e2baf0bc80a5312a94f..a4341c530fffca88c82813cc2ace2c0ae1df5345 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" + #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h similarity index 92% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h index 754920b60ca7439513a91ad0354833a2482b29c1..f495d857037c79a1783f8eb232fb57c20e229169 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ #include #include @@ -71,4 +71,4 @@ class PluginTensorRT : public nvinfer1::IPlugin { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc index cccc91226265ed139fb8db0b71c40b868f729562..871fb1210bd495dc3f5e8153bb6c3a361bf569f5 100644 --- 
a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h similarity index 91% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h index bbae9fb65c22cf69d2e7954436fd04dd16f7f6c8..9aa99a40b80de92a4d9b9ad36e88e693b8aa42dc 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -99,4 +99,4 @@ class TrtPluginRegistrar { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc index 129bdcdbc2f8d9d5215f45f381bcadf35e4fa75e..7d9c465c22beed0e252cbc26d6c533a0789d4f49 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc similarity index 94% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc index a8f60886c03c174a612e7a135b6eb7bb7cb9997a..f3d6b4ff476139693a5251ddf58a3200d8af8efc 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #include #if GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h similarity index 82% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h index 274ce42fec9283c643004d45fba461879fc5f2dc..e5eff15c19694093c7a5ea933a41375e8e01c8b9 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA @@ -43,4 +43,4 @@ string ExtractOpName(const void* serial_data, size_t serial_length, #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ diff --git a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..25fb3a13db9911673bac04652b8ed8ba842be93c --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -0,0 +1,69 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +"""Exposes the Python wrapper of TRTEngineOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +import platform +from tensorflow.python.framework import errors + +_trt_ops_so = None +_module_lock = threading.Lock() + + +def load_trt_ops(): + """Load TF-TRT op libraries so if it hasn't been loaded already.""" + global _trt_ops_so + + if platform.system() == "Windows": + raise RuntimeError("Windows platforms are not supported") + + with _module_lock: + if _trt_ops_so: + return + + try: + # pylint: disable=g-import-not-at-top,unused-variable + # This registers the TRT ops, it doesn't require loading TRT library. + from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op + # pylint: enable=g-import-not-at-top,unused-variable + except ImportError as e: + print("**** Failed to import TF-TRT ops. This is because the binary was " + "not built with CUDA or TensorRT enabled. ****") + raise e + + # TODO(laigd): we should load TF-TRT kernels here as well after removing the + # swig binding. + try: + # pylint: disable=g-import-not-at-top + from tensorflow.python.framework import load_library + from tensorflow.python.platform import resource_loader + # pylint: enable=g-import-not-at-top + + _trt_ops_so = load_library.load_op_library( + resource_loader.get_path_to_datafile("_trt_ops.so")) + except errors.NotFoundError as e: + no_trt_message = ( + "**** Failed to initialize TensorRT. This is either because the " + "TensorRT installation path is not in LD_LIBRARY_PATH, or because " + "you do not have it installed. 
If not installed, please go to " + "https://developer.nvidia.com/tensorrt to download and install " + "TensorRT ****") + print(no_trt_message) + raise e diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc similarity index 92% rename from tensorflow/contrib/tensorrt/segment/segment.cc rename to tensorflow/compiler/tf2tensorrt/segment/segment.cc index 6abc5226ccf96e472df77269bee6186726e5768d..3794929b1df3fa999de6ab218dc2ddfb96e4ac81 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include #include #include #include -#include "tensorflow/contrib/tensorrt/segment/union_find.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -29,11 +30,14 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { namespace segment { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; // A simple graph representation to mirror tensorflow::Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing @@ -225,6 +229,24 @@ SimpleGraph::~SimpleGraph() { for (auto x : edges_) delete x; } +// Define comparison functions for std::set with pointer keys so that behavior +// is deterministic. 
When using std::set with pointer key types, the items are +// sorted by pointer address which is non-deterministic. This can cause issues +// for INT8 mode because the graph is converted twice and non-determinism may +// cause a mismatch between the calibration tables of the conversions. +struct SimpleEdgePtrCompare { + bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + +struct NodePtrCompare { + bool operator()(const tensorflow::Node* lhs, + const tensorflow::Node* rhs) const { + return lhs->name() < rhs->name(); + } +}; + namespace { // Copied from TF ReverseDFS, which only works for tensorflow::Graph. @@ -476,7 +498,7 @@ tensorflow::Status SegmentGraph( // nodes. Iterate since combining two nodes may unblock other // combining. while (true) { - std::set contract_edges; + std::set contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; @@ -530,7 +552,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. - std::map> sg_map; + std::map> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -566,7 +588,8 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 2 --------------------------------- // Remove ineligible input/output nodes. for (auto& itr : sg_map) { - std::set& segment_nodes = itr.second; + std::set& segment_nodes = + itr.second; VLOG(1) << "Segment original size: " << segment_nodes.size(); while (true) { std::deque in_nodes_que, out_nodes_que; @@ -618,8 +641,9 @@ tensorflow::Status SegmentGraph( bool is_input_nodes, std::deque* que) { // Run a BFS on the queue to find all the input/output nodes. 
- std::set visited; - std::set logged(que->begin(), que->end()); + std::set visited; + std::set logged(que->begin(), + que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); @@ -653,9 +677,11 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const std::set& segment_nodes = itr.second; + const string& segment_root = itr.first; + // Return format does not require set comparator. + std::set segment_nodes(itr.second.begin(), itr.second.end()); if (VLOG_IS_ON(1)) { - string s = "parent=" + itr.first + ":"; + string s = "parent=" + segment_root + ":"; for (auto node : segment_nodes) s += " " + node->name(); VLOG(1) << "Segment " << segments->size() << ": " << s; } @@ -668,12 +694,10 @@ tensorflow::Status SegmentGraph( } // TODO(sami): Make segmenter placement aware once trtscopes are in place - std::set segment_node_names; - for (auto node : itr.second) segment_node_names.insert(node->name()); - const auto& dev_itr = device_maps.find(itr.first); + const auto& dev_itr = device_maps.find(segment_root); if (dev_itr == device_maps.end() || dev_itr->second.empty()) { VLOG(1) << "No device assigned to segment " << segments->size(); - segments->emplace_back(std::make_pair(segment_node_names, string())); + segments->emplace_back(std::make_pair(segment_nodes, string())); } else if (dev_itr->second.size() > 1) { string s("Segment "); StrAppend(&s, segments->size(), " has multiple devices attached: "); @@ -682,10 +706,10 @@ tensorflow::Status SegmentGraph( } LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin()); segments->emplace_back( - std::make_pair(segment_node_names, *(dev_itr->second.begin()))); + std::make_pair(segment_nodes, *(dev_itr->second.begin()))); } else { segments->emplace_back( - std::make_pair(segment_node_names, *(dev_itr->second.begin()))); + std::make_pair(segment_nodes, 
*(dev_itr->second.begin()))); } } if (VLOG_IS_ON(1)) { @@ -704,3 +728,6 @@ tensorflow::Status SegmentGraph( } // namespace segment } // namespace tensorrt } // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h similarity index 81% rename from tensorflow/contrib/tensorrt/segment/segment.h rename to tensorflow/compiler/tf2tensorrt/segment/segment.h index b9693aad1b764515459db6833b05221ea5b3a2d1..9622ddd593990e93ba1b54e9dfd0052006e20ced 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ #include #include @@ -24,15 +24,17 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" -namespace tensorflow { +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +namespace tensorflow { namespace tensorrt { namespace segment { -// Vector of segments, each entry contains a set of node names and a device name -// in the segment. -// TODO(aaroey): use node pointer instead of node name. -using SegmentNodesVector = std::vector, string>>; +// Vector of segments, each entry contains a set of node pointers and a device +// name in the segment. +using SegmentNodesVector = + std::vector, string>>; struct SegmentOptions { // Segment must contain at least this many nodes. 
@@ -60,4 +62,7 @@ tensorflow::Status SegmentGraph( } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc similarity index 97% rename from tensorflow/contrib/tensorrt/segment/segment_test.cc rename to tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index 4805ef9c61a7784a1c08cf5eaf504691bc9dbedc..e11ad2719740d908f93ef580a6b308469365f402 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -26,6 +26,9 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { namespace segment { @@ -75,7 +78,10 @@ class SegmentTest : public ::testing::Test { const std::vector>& expected_segments) { EXPECT_EQ(expected_segments.size(), segments.size()); for (int i = 0; i < segments.size(); ++i) { - const auto& segment_node_names = segments[i].first; + std::set segment_node_names; + for (const Node* node : segments[i].first) { + segment_node_names.insert(node->name()); + } const auto& expected = expected_segments[i]; for (const auto& name : expected) { EXPECT_TRUE(segment_node_names.count(name)) @@ -262,3 +268,6 @@ TEST_F(SegmentTest, BigIfElse) { } // namespace segment } // namespace tensorrt } // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h similarity index 92% rename from tensorflow/contrib/tensorrt/segment/union_find.h rename to tensorflow/compiler/tf2tensorrt/segment/union_find.h index 1c64ebbb0ae532a4776ab8963515d19fd3b23b4c..6458ae692fd7c922b5fc3bea2e55b613447dbde0 100644 --- a/tensorflow/contrib/tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ namespace tensorflow { namespace tensorrt { @@ -76,4 +76,4 @@ UnionFind* UnionFind::FindRoot() { } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc similarity index 100% rename from tensorflow/contrib/tensorrt/tensorrt_test.cc rename to tensorflow/compiler/tf2tensorrt/tensorrt_test.cc diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc similarity index 94% rename from tensorflow/contrib/tensorrt/test/utils.cc rename to tensorflow/compiler/tf2tensorrt/utils/test_utils.cc index 276308b3a0a6ce864969afb0179c6a3f00d6b70b..dd3c09d7e42358a1f9e6cc13be6198de58e38963 100644 --- a/tensorflow/contrib/tensorrt/test/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" #include #include #include "re2/re2.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h similarity index 85% rename from tensorflow/contrib/tensorrt/test/utils.h rename to tensorflow/compiler/tf2tensorrt/utils/test_utils.h index 4bb4120206cfaae70107e55d1818e3af2f02717a..d85875991b79014c4f173d3157ed02e6c96f045c 100644 --- a/tensorflow/contrib/tensorrt/test/utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tensorrt { @@ -41,4 +40,4 @@ string GetTestValue(const string& label); } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc similarity index 98% rename from tensorflow/contrib/tensorrt/resources/trt_allocator.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc index 
7a2e93414aed56525eaeac876cdac20404bcf6ab..1636cdc30c4df157ed124b160449af645f917252 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h similarity index 93% rename from tensorflow/contrib/tensorrt/resources/trt_allocator.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h index f857a9de055ee7668f0bf9bc97e030354505081b..59ffb42bad348c78cde32035aff8c7081528b3a6 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ #include @@ -81,4 +81,4 @@ class TRTDeviceAllocator : public TRTBaseAllocator { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc similarity index 80% rename from tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc index ad6b1d7d4c57d696d3dee3b479733e152e669211..e457c64928e5df84c7e2726ba3621420f013dbc9 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/platform/test.h" @@ -48,11 +48,14 @@ TEST(TRTAllocatorTest, Align) { 513ul, 700ul, 12345ul, 1ul << 32}) { for (uint64_t alignment = 1; alignment <= space * 4; alignment *= 2) { for (const uintptr_t ptr_val : - {1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1, - alignment + (alignment / 2)}) { + {static_cast(1), + alignment == 1 ? 
static_cast(1) : alignment - 1, + alignment, alignment + 1, alignment + (alignment / 2)}) { if (ptr_val % alignment == 0) { for (const uint64_t size : - {1ul, space == 1 ? 1ul : space - 1, space, space + 1}) { + {static_cast(1), + space == 1 ? static_cast(1) : space - 1, space, + space + 1}) { EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space)); } } else { @@ -62,8 +65,10 @@ TEST(TRTAllocatorTest, Align) { EXPECT_TRUE( RunTest(alignment, space - diff, ptr_val + diff, space - diff)); for (const uint64_t size : - {1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff, - space - diff + 1, space - 1}) { + {static_cast(1), + space - diff > 1 ? space - diff - 1 + : static_cast(1), + space - diff, space - diff + 1, space - 1}) { EXPECT_EQ(space - diff >= size, RunTest(alignment, size, ptr_val, space)); } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc similarity index 97% rename from tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc index dab1dd9343be7d5b033a3e04bf0b49fbbf37e9e5..5213fced1ea9220422245172f5b4a3f584a2a566 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include #include @@ -135,7 +135,7 @@ void TRTInt8Calibrator::setDone() { void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, std::size_t length) { - calibration_table_ = string((const char*)ptr, length); + calibration_table_ = string(static_cast(ptr), length); VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr << " length=" << length; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h similarity index 87% rename from tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h index 65466c9741989fda5f82fc27d813d026f35fe386..aa70b07f8d79848c362275815004db32cca128be 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ #include #include @@ -34,7 +34,12 @@ namespace tensorrt { // TRTs pull model for calibration. When TRT implements a means for // a push calibration This class should be updated accordingly +// IInt8EntropyCalibrator2 is preferred for TRT 5.1+.
+#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 { +#else struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { +#endif public: // Construct a calibrator for future calibration. TRTInt8Calibrator( @@ -96,4 +101,4 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { #endif #endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc similarity index 90% rename from tensorflow/contrib/tensorrt/log/trt_logger.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc index dda0dc9e712eb726800abfb6084f4f708d04825b..6bc842ed5ca7e03018157060a332338cdc926f14 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -26,6 +26,9 @@ namespace tensorrt { void Logger::log(Severity severity, const char* msg) { // Suppress info-level messages switch (severity) { +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + case Severity::kVERBOSE: +#endif case Severity::kINFO: { // Mark TRT info messages as debug! 
VLOG(2) << name_ << " " << msg; break; diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h similarity index 86% rename from tensorflow/contrib/tensorrt/log/trt_logger.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_logger.h index 96ccacb791e40143c5c4d9d691bb353702f9a28b..22f4de970a80765b0e1e7e8816134d83aaec7c73 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ #include "tensorflow/core/platform/types.h" @@ -41,4 +41,4 @@ class Logger : public nvinfer1::ILogger { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..09c47b36b0ad8074e749342e7d08f139da7ea1f4 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -0,0 +1,192 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/errors.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +template +class LRUCache { + public: + typedef Value value_type; + typedef Key key_type; + typedef HashFunction hasher; + typedef typename std::unordered_map map_type; + typedef typename map_type::iterator iterator; + typedef typename map_type::const_iterator const_iterator; + + LRUCache() : capacity_(0) {} + explicit LRUCache(size_t capacity) : capacity_(capacity) {} + + size_t capacity() const { return capacity_; } + + void reserve(size_t capacity) { + capacity_ = capacity; + DiscardOld(); + } + + size_t size() const { return objects_.size(); } + + size_t count(const key_type& key) const { return objects_.count(key); } + + value_type& at(const key_type& key) { return Touch(key); } + + const_iterator begin() const { return objects_.begin(); } + const_iterator end() const { return objects_.end(); } + + iterator begin() { return objects_.begin(); } + iterator end() { return objects_.end(); } + + template + std::pair emplace(Args&&... 
args) { + DiscardOld(1); + std::pair result = + objects_.emplace(std::forward(args)...); + key_type key = result.first->first; + if (result.second) { + keys_.push_front(key); + } else { + TouchNoCheck(key); // The key must exist in this case. + } + return result; + } + + private: + std::unordered_map objects_; + std::list keys_; + size_t capacity_; + value_type not_found_value_; + + value_type& Touch(const key_type& key) { + // Check that the key exists, and let it return std::out_of_range error if + // not. + value_type& value = objects_.at(key); + TouchNoCheck(key); + return value; + } + + void TouchNoCheck(const key_type& key) { + auto rank = std::find(keys_.begin(), keys_.end(), key); + if (rank != keys_.begin()) { + keys_.erase(rank); + keys_.push_front(key); + } + } + + // Creates n free positions in cache + tensorflow::Status DiscardOld(size_t n = 0) { + if (n > capacity_) { + return tensorflow::errors::Internal( + "Insufficient capacity in cache (capacity = ", capacity_, + ", requested ", n, ")"); + } + while (objects_.size() > (capacity_ - n)) { + key_type discard_key = keys_.back(); + keys_.pop_back(); + objects_.erase(discard_key); + } + return tensorflow::Status::OK(); + } +}; + +// Define a hash function for vector because it is used as the key +// for the engine cache. +struct VectorTensorShapeHasher { + std::size_t operator()( + const std::vector& key) const { + return std::hash()(TensorShapeUtils::ShapeListString(key)); + } +}; + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +struct EngineContext { + EngineContext() {} // Creates an empty context. 
+ EngineContext( + TrtUniquePtrType&& input_cuda_engine, + TrtUniquePtrType&& input_execution_context) + : cuda_engine(std::move(input_cuda_engine)), + execution_context(std::move(input_execution_context)) {} + + mutex mu; + TrtUniquePtrType cuda_engine; + TrtUniquePtrType execution_context + GUARDED_BY(mu); +}; + +class TRTEngineCacheResource : public tensorflow::ResourceBase { + public: + TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity) + : cache_(capacity) { + auto device = ctx->device(); + auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(ERROR) << "Can't find device allocator for gpu device " + << device->name(); + allocator_ = nullptr; + } else { + allocator_.reset(new TRTDeviceAllocator(alloc)); + } + } + + string DebugString() const override { + std::stringstream oss; + using std::dec; + using std::endl; + using std::hex; + oss << "TRTEngineCacheResource: "; + oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", "; + oss << "LRUCache = " << hex << &cache_ << dec << endl; + oss << "Containing " << cache_.size() << " entries: " << endl; + for (const auto& item : cache_) { + oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex + << "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", " + << "IExecutionContext: " << item.second.get()->execution_context.get() + << dec << endl; + } + return oss.str(); + } + + // Keep device allocator for TRT. + std::unique_ptr allocator_; + + // Declare cache after allocator so that it is destroyed before allocator is. 
+ LRUCache, std::unique_ptr, + VectorTensorShapeHasher> + cache_; +}; + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0aa5eb8f7d4ad062c2d8622fa5aa55f823f80dd5 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tensorrt { + +TEST(LRUCacheTest, Basic) { + LRUCache> cache; + cache.reserve(2); + // Insert 10 + cache.emplace(10, 100); + EXPECT_EQ(cache.size(), 1); + EXPECT_EQ(cache.count(10), 1); + EXPECT_EQ(cache.at(10), 100); + EXPECT_EQ(cache.count(100), 0); + // Insert 20 + cache.emplace(20, 200); + EXPECT_EQ(cache.size(), 2); + EXPECT_EQ(cache.count(10), 1); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.at(10), 100); + EXPECT_EQ(cache.at(20), 200); + EXPECT_EQ(cache.count(100), 0); + EXPECT_EQ(cache.count(200), 0); + // Insert 30, Evicting 10 + cache.emplace(30, 300); + EXPECT_EQ(cache.count(10), 0); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.count(30), 1); + // Touch 20 + cache.at(20); + // Insert 40, Evicting 30 + cache.emplace(40, 400); + EXPECT_EQ(cache.count(10), 0); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.count(30), 0); + EXPECT_EQ(cache.count(40), 1); +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e553079b19a3e5d0739cc6ac79a84f3b6a1fc4e --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +TRTCalibrationResource::~TRTCalibrationResource() { + VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + builder_.reset(); + engine_.reset(); + // We need to manually destroy the builder and engine before the allocator + // is destroyed. + allocator_.reset(); +} + +string TRTCalibrationResource::DebugString() const { + std::stringstream oss; + using std::dec; + using std::endl; + using std::hex; + oss << " Calibrator = " << hex << calibrator_.get() << dec << endl + << " Builder = " << hex << builder_.get() << dec << endl + << " Engine = " << hex << engine_.get() << dec << endl + << " Logger = " << hex << &logger_ << dec << endl + << " Allocator = " << hex << allocator_.get() << dec << endl + << " Thread = " << hex << thr_.get() << dec << endl; + return oss.str(); +} + +Status TRTCalibrationResource::SerializeToString(string* serialized) { + calibrator_->waitAndSetDone(); + thr_->join(); + *serialized = calibrator_->getCalibrationTableAsString(); + if (serialized->empty()) { + return tensorflow::errors::Unknown("Calibration table is empty."); + } + return Status::OK(); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h new file 
mode 100644 index 0000000000000000000000000000000000000000..5e8d4b3b738df09b0c2ea82dcc06e9b23a708385 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class SerializableResourceBase : public tensorflow::ResourceBase { + public: + virtual Status SerializeToString(string* serialized) = 0; +}; + +class TRTCalibrationResource : public SerializableResourceBase { + public: + ~TRTCalibrationResource() override; + + string DebugString() const override; + + Status SerializeToString(string* serialized) override; + + // Lookup table for temporary staging areas of input tensors for calibration. 
+ std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector device_tensors_; + + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + std::unique_ptr allocator_; + tensorflow::tensorrt::Logger logger_; + // TODO(sami): Use threadpool threads! + std::unique_ptr thr_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5a0d9b9af9d55a8dee809d3cf909bce39c3b8b6c..7d9e7b9fc1f7ea83d6aa982afb5df097b0bdbf77 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test") package_group( name = "internal", @@ -24,7 +24,7 @@ package( ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") cc_library( name = "tf2xla_supported_ops_lib", @@ -60,6 +60,14 @@ xla_proto_library( ], ) +xla_py_proto_library( + name = "tf2xla_py", + has_services = False, + api_version = 2, + visibility = ["//visibility:public"], + deps = [":tf2xla_proto"], +) + xla_proto_library( name = "host_compute_metadata_proto", srcs = ["host_compute_metadata.proto"], @@ -204,6 +212,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", @@ -224,6 +233,7 @@ cc_library( 
"@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], alwayslink = 1, ) @@ -244,6 +254,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -280,6 +291,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -314,11 +326,13 @@ tf_cc_test( ":tf2xla_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -443,6 +457,7 @@ cc_library( hdrs = [ "dump_graph.h", ], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/jit:flags", "//tensorflow/core:framework", @@ -668,8 +683,31 @@ cc_library( name = "side_effect_util", srcs = ["side_effect_util.cc"], hdrs = ["side_effect_util.h"], + visibility = [":friends"], deps = [ "//tensorflow/core:core_cpu", "@com_google_absl//absl/strings", ], ) + +tf_cuda_cc_test( + name = "fused_batchnorm_reserve_space_test", + size = "medium", + srcs = ["fused_batchnorm_reserve_space_test.cc"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/compiler/jit", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + 
"//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/algorithm:container", + ], +) diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h index dfc1e8b8aebcf3142e9f61f60171c6b58634c71d..78970fb39bae7067c7668baa2aec65732b5b2352 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime.h +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h @@ -104,7 +104,7 @@ class BufferInfo { private: BufferInfo() = default; - enum class Kind : unsigned { + enum class Kind : uint64 { kConstant, kTempBuffer, kEntryParameter, diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index c693e42d26712d55852f45c806215fc1f1b9a030..8aa162be47c9181e215de6a2eb660215135ff6eb 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; @@ -41,6 +43,43 @@ using xla::StatusOr; namespace tensorflow { namespace functionalize_cond { +bool AncestorNode::operator<(const AncestorNode& other) const { + return (output_tensor.node->id() < other.output_tensor.node->id()) || + (output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index < other.output_tensor.index) || + (output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index == other.output_tensor.index && + type < other.type); +} + +bool AncestorNode::operator==(const AncestorNode& other) const { + return output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index == other.output_tensor.index && type == other.type; +} + +size_t 
AncestorNode::Hash::operator()(const AncestorNode& ancestor) const { + size_t h = std::hash()(ancestor.output_tensor.node->id()); + h = Hash64Combine(h, std::hash()(ancestor.output_tensor.index)); + return Hash64Combine(h, std::hash()(static_cast(ancestor.type))); +} + +typedef std::tuple + ClusterTuple; + +struct ClusterTupleLessThan { + bool operator()(const ClusterTuple& a, const ClusterTuple& b) const { + if (std::tie(std::get<0>(a), std::get<1>(a)) < + std::tie(std::get<0>(b), std::get<1>(b))) { + return true; + } else if (std::tie(std::get<0>(a), std::get<1>(a)) == + std::tie(std::get<0>(b), std::get<1>(b))) { + return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b)); + } else { + return false; + } + } +}; + // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { return absl::StrCat(tensor.node->name(), ":", tensor.index); @@ -145,10 +184,10 @@ size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const { if (map.empty()) return 0; // Compute hash of the front element. auto it = map.begin(); - size_t h = hash()(*it); + size_t h = AncestorNode::Hash()(*it); for (++it; it != map.end(); ++it) { // Combine the has with the different elements in the map. - h = Hash64Combine(h, hash()(*it)); + h = Hash64Combine(h, AncestorNode::Hash()(*it)); } return h; } @@ -229,7 +268,17 @@ string StateMap::CondStateToString(StateMap::CondId id) const { } string StateMap::AncestorStateToString(const Node* node) const { - if (auto id = LookupAncestorId(node)) return NodesToString(*id); + if (auto id = LookupAncestorId(node)) { + return absl::StrCat( + "{", + absl::StrJoin(*id, ",", + [](string* output, const AncestorNode& ancestor) { + absl::StrAppend(output, + ancestor.output_tensor.node->name(), + ":", ancestor.output_tensor.index); + }), + "}"); + } return "{}"; } @@ -247,7 +296,9 @@ class Conditional { Status AddMerge(Node* m); // Constructs an If node from the merge nodes. 
- Status BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library); + Status BuildAndReplace( + Graph* graph, FunctionLibraryDefinition* library, + std::unordered_map* merge_to_replacement); private: // Extracts the then/else bodies: creates new graphs with the nodes @@ -262,10 +313,15 @@ class Conditional { Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library); // Adds input edges to If node. - Status AddInputEdges(Graph* graph); + Status AddInputEdges( + Graph* graph, + const std::unordered_map& merge_to_replacement); // Adds output edges from If node. - Status AddOutputEdges(Graph* graph); + // Record new output tensor for all Merge nodes in 'merge_to_replacement'. + Status AddOutputEdges( + Graph* graph, + std::unordered_map* merge_to_replacement); // Adds switch node that is part of this conditional. Status AddSwitch(Node* s); @@ -640,7 +696,8 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If", library); + NodeDebugInfo debug_info((*merges_.begin())->def()); + NodeDefBuilder builder(name(), "If", library, &debug_info); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -704,9 +761,9 @@ Status Conditional::BuildIfNode(Graph* graph, } builder.Device(predicate_.node->assigned_device_name()); // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), - predicate_.index, - predicate_.node->output_type(0))); + builder.Input( + NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index, + predicate_.node->output_type(predicate_.index))); // ... followed by the other inputs. 
builder.Input(inputs); @@ -719,12 +776,29 @@ Status Conditional::BuildIfNode(Graph* graph, return Status::OK(); } -Status Conditional::AddInputEdges(Graph* graph) { +Status Conditional::AddInputEdges( + Graph* graph, + const std::unordered_map& merge_to_replacement) { VLOG(2) << "AddInputEdges for " << if_node_->name(); int index = 0; // Add predicate input. - graph->AddEdge(const_cast(predicate_.node), predicate_.index, if_node_, - index++); + if (predicate_.node->IsMerge()) { + // If the predicate is a Merge node, we should not use Merge output as + // predicate. Instead, we should use the corresponding If output in + // 'merge_to_replacement'. Otherwise, this Conditional's If node is still + // connected to the predicate Merge node; and when we call + // DeleteReachableAndDeadNodes(), the predicate Merge node and this + // Conditional's If node will be removed. + auto iter = merge_to_replacement.find(predicate_.node); + if (iter == merge_to_replacement.end()) { + return errors::Internal("Cannot find replacement for Merge node ", + predicate_.node->name()); + } + graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++); + } else { + graph->AddEdge(const_cast(predicate_.node), predicate_.index, + if_node_, index++); + } // Add function body inputs. for (auto& arg : cond_arg_nodes_) { if (arg.src_output == Graph::kControlSlot) { @@ -739,7 +813,9 @@ Status Conditional::AddInputEdges(Graph* graph) { return Status::OK(); } -Status Conditional::AddOutputEdges(Graph* graph) { +Status Conditional::AddOutputEdges( + Graph* graph, + std::unordered_map* merge_to_replacement) { VLOG(2) << "AddOutputEdges for " << if_node_->name(); int i = 0; for (Node* node : merges_) { @@ -763,6 +839,10 @@ Status Conditional::AddOutputEdges(Graph* graph) { graph->AddEdge(if_node_, i, dst, dst_input); } } + + // Record corresponding output tensor in 'merge_to_replacement'. 
+ (*merge_to_replacement)[node] = OutputTensor{if_node_, i}; + ++i; } for (Node* n : external_control_outputs_) { @@ -772,8 +852,9 @@ Status Conditional::AddOutputEdges(Graph* graph) { return Status::OK(); } -Status Conditional::BuildAndReplace(Graph* graph, - FunctionLibraryDefinition* library) { +Status Conditional::BuildAndReplace( + Graph* graph, FunctionLibraryDefinition* library, + std::unordered_map* merge_to_replacement) { VLOG(1) << "Build If and replace merge nodes " << NodesToString(this->merges_); if (replaced_) return Status::OK(); @@ -792,8 +873,8 @@ Status Conditional::BuildAndReplace(Graph* graph, } TF_RETURN_IF_ERROR(BuildIfNode(graph, library)); - TF_RETURN_IF_ERROR(AddInputEdges(graph)); - TF_RETURN_IF_ERROR(AddOutputEdges(graph)); + TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement)); + TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement)); TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); // Check that the if_node doesn't feed into itself. @@ -935,6 +1016,10 @@ StatusOr FunctionalizeCond::JoinCondStatesMerge( VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " << DebugString(dst); if (state_map_.IsEmpty(dst)) return src; + if (state_map_.IsEmpty(src)) { + return errors::Internal("Merge node ", merge->name(), + " has input that's not in any CondContext."); + } if (state_map_.IsDead(src)) return src; if (state_map_.IsDead(dst)) return dst; @@ -1169,8 +1254,17 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) { if (other_id != id && other_id != nullptr) { state.insert(other_id->begin(), other_id->end()); } - if (IsSwitch(src) || IsMerge(src)) { - state.insert(src); + if (IsMerge(src)) { + state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge}); + } else if (IsSwitch(src)) { + OutputTensor pred; + // For dead switch nodes, GetSwitchPredicate() will fail, and we use + // the switch node directly as ancestor. 
+ if (GetSwitchPredicate(*src, &pred).ok()) { + state.insert({pred, AncestorNode::AncestorNodeType::kPred}); + } else { + state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch}); + } } return state_map_.GetAncestorId(state); }; @@ -1316,16 +1410,30 @@ Status FunctionalizeCond::FunctionalizeInternal() { // Sort the merge nodes from innermost outwards. SortMergeNodes(&merge_order); - // Cluster merge nodes by CondId and AncestorId in order of nesting. - using ClusterPair = std::pair; + // Cluster merge nodes by (CondId, AncestorId, predicate) in order of + // nesting. (CondId, AncestorId) is not enough, e.g. + // pred1 = array_ops.placeholder(dtypes.bool, name='pred1') + // pred2 = array_ops.placeholder(dtypes.bool, name='pred2') + // cond1 = control_flow_ops.cond(pred1, ...) + // cond2 = control_flow_ops.cond(pred2, ...) + // cond3 = control_flow_ops.cond(pred1, use cond1 and cond2) + // cond4 = control_flow_ops.cond(pred2, use cond1 and cond2) + // cond3 and cond4 have the same (CondId, AncestorId), but they should not + // be merged into one "If" node (because they have different predicates). 
std::deque> merge_clusters; - std::map merge_cluster_index; + std::map merge_cluster_index; for (Node* merge : merge_order) { auto cond_id = state_map_.LookupCondId(merge); if (state_map_.IsDead(cond_id)) continue; - ClusterPair key = - std::make_pair(cond_id, state_map_.LookupAncestorId(merge)); + auto predicate = merge_to_predicate_.find(merge); + if (predicate == merge_to_predicate_.end()) { + return errors::Internal("Cannot find predicate for Merge node ", + merge->name()); + } + + ClusterTuple key = std::make_tuple( + cond_id, state_map_.LookupAncestorId(merge), predicate->second); auto idx = merge_cluster_index.find(key); if (idx == merge_cluster_index.end()) { merge_cluster_index[key] = merge_clusters.size(); @@ -1344,7 +1452,8 @@ Status FunctionalizeCond::FunctionalizeInternal() { Conditional cond(merge_to_predicate_.at(cluster.front()), this, &state_map_); for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); - TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); + TF_RETURN_IF_ERROR( + cond.BuildAndReplace(graph_, library_, &merge_to_replacement_)); if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 8525d7af61b4471e53a9ae16b081060bfd234c9c..d85800fb8ee65a354716bf6601c6bc40eca9a10d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -43,6 +43,33 @@ enum class BranchType { kNeither = 3, }; +// When we keep track of which switch/merge nodes feed into a node, we record +// 1) predicate for non-dead switch node, +// 2) the switch node itself for dead switch node, +// 3) the merge node itself for merge node. +// Case 1) is an optimization. 
With this optimization, if there are nodes from +// different switch nodes but those switch nodes have the same predicate, the +// nodes will still have same AncestorState, and they will be clustered into a +// single "If". +struct AncestorNode { + enum class AncestorNodeType { + kPred = 0, + kSwitch = 1, + kMerge = 2, + }; + + OutputTensor output_tensor; + AncestorNodeType type; + + // Compare two AncestorNodes by (node id, index, type). + bool operator<(const AncestorNode& other) const; + bool operator==(const AncestorNode& other) const; + + struct Hash { + size_t operator()(const AncestorNode&) const; + }; +}; + // StateMap is responsible for mapping from each graph Node to // * a CondState, where each CondState is a map from predicate to branch (i,e., // what predicates have to hold or not hold). @@ -68,7 +95,7 @@ class StateMap { using CondId = const CondState*; // Keep track of which switch/merge node's feed into a node's values. - using AncestorState = std::set; + using AncestorState = std::set; // Every unique ID is mapped to a AncestorState. using AncestorId = const AncestorState*; @@ -232,6 +259,9 @@ class FunctionalizeCond { // Mapping from merge nodes to predicate. std::unordered_map merge_to_predicate_; + // Mapping from merge nodes to corresponding If node outputs. 
+ std::unordered_map merge_to_replacement_; + FunctionLibraryDefinition* library_; Graph* graph_; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index b0aabd63bbda784b3b7103a438ce025eea0cd93b..05fa1ee92dc172bd11cec9f99e3884996e00791f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -101,6 +101,17 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { TF_EXPECT_OK(t.status()); } +TEST_F(FunctionalizeCondTest, JoinCondStatesMergeWithInputNotInCondContext) { + Tensor val_tensor(DT_INT32, TensorShape()); + val_tensor.flat().setZero(); + Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); + Node* m = test::graph::Merge(graph_.get(), val, val); + + StateMap::CondState cond_state; + auto joined_or = JoinCondStatesMerge(m, /*src=*/nullptr, &cond_state); + EXPECT_FALSE(joined_or.ok()); +} + } // namespace } // namespace functionalize_cond } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4535ece374ceb801e450af98a21d5a4c5e8f2a29 --- /dev/null +++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { +Status GetTestDevice(Session* session, string* test_device) { + std::vector devices; + TF_RETURN_IF_ERROR(session->ListDevices(&devices)); + + bool found_cpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) { + return device.device_type() == "CPU"; + }); + + bool found_gpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) { + return device.device_type() == "GPU"; + }); + + if (!found_gpu && !found_cpu) { + return errors::Internal("Expected at least one CPU or GPU!"); + } + + *test_device = found_gpu ? "GPU" : "CPU"; + VLOG(2) << "Using test device " << *test_device; + return Status::OK(); +} + +void FillZeros(Tensor* tensor) { + auto flat = tensor->flat(); + for (int i = 0; i < flat.size(); i++) { + flat.data()[i] = 0.0f; + } +} + +// This test checks that the implementation outputs from FusedBatchnorm +// training, reserve_space_{1|2}, are what we assume them to be in the TF/XLA +// lowering. +// +// If this test starts failing then it doesn't indicate that TF/cudnn have +// violated their contract, but it indicates that we need to update the TF/XLA +// lowering for FusedBatchnorm training to match the new implementation-defined +// behavior. 
+TEST(FusedBatchnormReserveSpaceTest, Test) { + using ::tensorflow::ops::Const; + using ::tensorflow::ops::FusedBatchNorm; + + std::unique_ptr session( + tensorflow::NewSession(tensorflow::SessionOptions{})); + + string test_device; + TF_ASSERT_OK(GetTestDevice(session.get(), &test_device)); + + Scope root = tensorflow::Scope::NewRootScope(); + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + + Tensor scale_data(DT_FLOAT, TensorShape({10})); + FillZeros(&scale_data); + Output scale = + Const(root.WithOpName("scale"), Input::Initializer(scale_data)); + + Tensor offset_data(DT_FLOAT, TensorShape({10})); + FillZeros(&offset_data); + Output offset = + Const(root.WithOpName("offset"), Input::Initializer(offset_data)); + + Tensor mean_data(DT_FLOAT, TensorShape({0})); + Output mean = Const(root.WithOpName("offset"), Input::Initializer(mean_data)); + + Tensor variance_data(DT_FLOAT, TensorShape({0})); + Output variance = + Const(root.WithOpName("variance"), Input::Initializer(variance_data)); + + string tf_device = absl::StrCat("/device:", test_device, ":0"); + string xla_device = absl::StrCat("/device:XLA_", test_device, ":0"); + + FusedBatchNorm fused_batch_norm_tf( + root.WithOpName("fused_batch_norm_tf").WithDevice(tf_device), input, + scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true)); + FusedBatchNorm fused_batch_norm_xla( + root.WithOpName("fused_batch_norm_xla").WithDevice(xla_device), input, + scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true)); + + tensorflow::GraphDef graph; + TF_ASSERT_OK(root.ToGraphDef(&graph)); + + TF_ASSERT_OK(session->Create(graph)); + + Tensor input_data(DT_FLOAT, TensorShape({10, 10, 10, 10})); + auto flat_input = input_data.flat(); + for (int i = 0; i < flat_input.size(); i++) { + flat_input.data()[i] = (i - 5) / 1000.0f; + } + + std::vector results; + TF_ASSERT_OK(session->Run({{"input", input_data}}, + {fused_batch_norm_tf.reserve_space_1.name(), + 
fused_batch_norm_xla.reserve_space_1.name(), + fused_batch_norm_tf.reserve_space_2.name(), + fused_batch_norm_xla.reserve_space_2.name()}, + {}, &results)); + + test::ExpectClose(results[0], results[1], /*atol=*/1e-4); + test::ExpectClose(results[2], results[3], /*atol=*/1e-4); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index efb75749722893100494e089c0beb96944e9f1d4..5e4699bbb6218089d2e76a36c7351bf7fbd23264 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -88,6 +89,9 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, case XlaExpression::Kind::kResource: return errors::Unimplemented( "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kTensorList: + return errors::Unimplemented( + "TensorList as function argument is not yet implemented."); case XlaExpression::Kind::kInvalid: return errors::InvalidArgument("Invalid function argument"); } @@ -191,6 +195,9 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, // into the functions. 
XlaOpKernelContext xla_op_context(op_context); + XlaContext& context = XlaContext::Get(op_context); + auto* b = context.builder(); + XlaCompiler* compiler = xla_op_context.compiler(); NameAttrList func; @@ -219,8 +226,12 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + bool add_token_input_output = + HasNodeAttr(n->def(), kXlaTokenInputNodesAttrName); + XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = false; + compile_options.add_token_input_output = add_token_input_output; XlaCompiler::CompilationResult result; TF_RETURN_IF_ERROR( compiler->CompileFunction(compile_options, func, arguments, &result)); @@ -234,9 +245,19 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, } handles.push_back(expressions[i]->handle()); } - - XlaContext& context = XlaContext::Get(op_context); - auto* b = context.builder(); + if (add_token_input_output) { + std::vector token_input_nodes; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->def(), kXlaTokenInputNodesAttrName, &token_input_nodes)); + std::vector token_inputs; + for (const string& node_name : token_input_nodes) { + auto token_or = compiler->GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ConsumeValueOrDie()); + } + xla::XlaOp token_input = xla::AfterAll(b, token_inputs); + handles.push_back(token_input); + } auto output_handle = xla::Call(b, *result.computation, handles); // The output handle of `Call` computation is a tuple type. 
Unzip it so @@ -251,6 +272,10 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, ++computation_output; } } + if (add_token_input_output) { + TF_RETURN_IF_ERROR(compiler->SetNodeToken( + n->name(), xla::GetTupleElement(output_handle, computation_output))); + } return b->first_error(); } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 8bc329229648c5aced8d06c99b170803bb3a90f8..343568b2392595a2347bde41f0a2e2559fb1de19 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,16 +1,11 @@ +load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library") + licenses(["notice"]) # Apache 2.0 package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_copts") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl", -) - tf_kernel_library( name = "xla_ops", srcs = [ @@ -39,6 +34,7 @@ tf_kernel_library( "dynamic_slice_ops.cc", "dynamic_stitch_op.cc", "elu_op.cc", + "empty_op.cc", "extract_image_patches_op.cc", "fake_param_op.cc", "fake_quantize_ops.cc", @@ -106,15 +102,18 @@ tf_kernel_library( "variable_ops.cc", "xla_broadcast_helper_op.cc", "xla_conv_op.cc", + "xla_dequantize_op.cc", "xla_dot_op.cc", "xla_pad_op.cc", "xla_reduce_op.cc", "xla_select_and_scatter_op.cc", + "xla_self_adjoint_eig_op.cc", ], hdrs = [ "index_ops.h", "shape_util.h", ], + tags = ["optonly"], deps = [ ":conv_op_helpers", ":if_op", @@ -122,12 +121,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:broadcast", - "//tensorflow/compiler/tf2xla/lib:cholesky", - "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/tf2xla/lib:while_loop", 
"//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -140,20 +136,38 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:quantize", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:sorting", - "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/core:bitwise_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:linalg_ops_op_lib", + "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:random_ops_op_lib", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:sparse_ops_op_lib", "//tensorflow/core:spectral_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", + "//tensorflow/core:training_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", "//tensorflow/core/kernels:constant_op", diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 
795ea09831e183a26fb3498b9bbaf9c3adaef9ed..5554d7a377d38554058aa731770ee10e400bc535 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -53,7 +53,11 @@ class XlaArgOp : public XlaOpKernel { const XlaExpression& arg = ctx->xla_context()->args()[index_]; OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, errors::InvalidArgument("Invalid/missing argument expression")); - ctx->SetOutputExpression(0, arg); + if (ctx->expected_output_dtype(0) == DT_VARIANT) { + ctx->SetTensorListOutput(0, arg.handle()); + } else { + ctx->SetOutputExpression(0, arg); + } } private: @@ -63,6 +67,8 @@ class XlaArgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); }; -REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp); +REGISTER_XLA_OP( + Name("_Arg").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + XlaArgOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 0e2f335f3354e3ae6008bdc0ac0b80683fe479c1..f1d78c87527eb5f818dcf92209feabe33653a625 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" @@ -34,6 +36,7 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); + is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT; } void Compile(XlaOpKernelContext* ctx) override { @@ -71,7 +74,18 @@ class FusedBatchNormOp : public XlaOpKernel { // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. ctx->SetOutput(3, xla::GetTupleElement(output, 1)); - ctx->SetOutput(4, xla::GetTupleElement(output, 2)); + if (is_on_gpu_) { + // The last two outputs from the FusedBatchNorm training TensorFlow GPU + // op are implementation defined. 
For now we rely on the in-practice + // behavior of the op: + // output 3 is the mean + // output 4 is rsqrt(variance + epsilon) + xla::XlaOp variance = xla::GetTupleElement(output, 2); + ctx->SetOutput(4, xla::Rsqrt(xla::Add( + variance, xla::ScalarLike(variance, epsilon_)))); + } else { + ctx->SetOutput(4, xla::GetTupleElement(output, 2)); + } } else { xla::XlaOp output = xla::BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), @@ -89,6 +103,7 @@ class FusedBatchNormOp : public XlaOpKernel { float epsilon_; TensorFormat data_format_; bool is_training_; + bool is_on_gpu_; }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); @@ -104,6 +119,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); + is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT; } void Compile(XlaOpKernelContext* ctx) override { @@ -130,6 +146,22 @@ class FusedBatchNormGradOp : public XlaOpKernel { xla::XlaOp scale_backprop; xla::XlaOp offset_backprop; if (is_training_) { + if (is_on_gpu_) { + // The last two inputs to the FusedBatchNormGrad training TensorFlow GPU + // op are implementation defined. For now we rely on the in-practice + // behavior of the op: input 3 is the mean and input 4 is + // rsqrt(variance + epsilon). + // + // The XLA op expects: + // input 3 is the mean + // input 4 is the variance + // + // so we adjust input 4 here. 
+ xla::XlaOp one = xla::ScalarLike(var, 1.0f); + xla::XlaOp epsilon = xla::ScalarLike(var, epsilon_); + var = xla::Sub(one / (var * var), epsilon); + } + xla::XlaOp output = xla::BatchNormGrad(activations, scale, mean, var, grad_backprop, epsilon_, feature_index); @@ -158,9 +190,8 @@ class FusedBatchNormGradOp : public XlaOpKernel { offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) - auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); - auto scratch1 = xla::Pow( - xla::Add(var, xla::ConstantR0(b, epsilon_)), neg_half); + auto epsilon = XlaHelpers::FloatLiteral(b, scale_dtype, epsilon_); + auto scratch1 = xla::Rsqrt(xla::Add(var, epsilon)); // scratch2 = sum(y_backprop * (x - mean)) auto mul = @@ -187,6 +218,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { TensorFormat data_format_; float epsilon_; bool is_training_; + bool is_on_gpu_; }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 46e5d68c78fd9ff26a88dc2a1484c3a67b76f4f3..6b675fa8a94e0bc932baaa359565cbc8e4614ee5 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -39,7 +39,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, OP_REQUIRES( ctx, - xla::ShapeUtil::Rank(crops.shape()) == 2 && + crops.shape().rank() == 2 && block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1), errors::InvalidArgument("crops should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index e7f369b761f36a717ea5fb536780af91a8955b1e..33bdf9aec3167b0277f3c1db18c9e247ed9bb5d1 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -48,8 +48,11 @@ class BiasOp : public XlaOpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_shape), errors::InvalidArgument("Biases must be 1D: ", bias_shape.DebugString())); - int feature_dim = (data_format_ == FORMAT_NHWC) ? input_shape.dims() - 1 - : input_shape.dims() - 3; + + // feature_dim is the channel (C) dimension of the data. + int feature_dim = (data_format_ == FORMAT_NHWC) + ? input_shape.dims() - 1 + : /*data_format == FORMAT_NCHW*/ 1; OP_REQUIRES( ctx, feature_dim >= 0, errors::InvalidArgument("Input tensor does not have enough dimensions " @@ -91,9 +94,10 @@ class BiasAddGradOp : public XlaOpKernel { errors::InvalidArgument("Input tensor must be at least 2D: ", out_backprop_shape.DebugString())); + // feature_dim is the channel (C) dimension of the data. int feature_dim = (data_format_ == FORMAT_NHWC) ? out_backprop_shape.dims() - 1 - : out_backprop_shape.dims() - 3; + : /*data_format == FORMAT_NCHW*/ 1; OP_REQUIRES( ctx, feature_dim >= 0, errors::InvalidArgument("Input tensor does not have enough dimensions " diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 5e9280c1fe692037b0a842a92ef5a8c28b854a54..ad6b334326a470442c8c0d79b725345d4165be10 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -20,7 +20,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -165,12 +167,8 @@ XLA_MAKE_BINARY( xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), lhs, extend_dimensions)); -static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { - return xla::Mul(x, x); -} - XLA_MAKE_BINARY(SquaredDifference, - Square(b, xla::Sub(lhs, rhs, extend_dimensions))); + xla::Square(xla::Sub(lhs, rhs, extend_dimensions))); XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); @@ -195,8 +193,8 @@ XLA_MAKE_BINARY(SoftplusGrad, // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, xla::Div(lhs, - Square(b, xla::Add(XlaHelpers::One(b, input_type(0)), - xla::Abs(rhs))))); + xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Abs(rhs))))); XLA_MAKE_BINARY(TanhGrad, xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)), @@ -204,6 +202,8 @@ XLA_MAKE_BINARY(TanhGrad, XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NextAfter, xla::NextAfter(lhs, rhs)); + #undef XLA_MAKE_BINARY class ApproximateEqualOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 8cc2479dd555380da7500abe6b2aca380110333b..ca2152d6c103e05c06809d85d9529720ff112217 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -19,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -31,6 +33,7 @@ class CastOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_)); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_)); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_)); } void Compile(XlaOpKernelContext* ctx) override { @@ -48,6 +51,36 @@ class CastOp : public XlaOpKernel { // imaginary part. output = xla::ConvertElementType(xla::Real(input), dst_type_); } else { + if (use_truncation_) { + OP_REQUIRES( + ctx, + xla::primitive_util::IsFloatingPointType(src_type_) && + xla::primitive_util::IsFloatingPointType(dst_type_), + errors::Unimplemented("Truncate attribute is only " + "implemented for floating point datatypes.")); + int mantissa_difference = + xla::primitive_util::SignificandWidth(src_type_) - + xla::primitive_util::SignificandWidth(dst_type_); + OP_REQUIRES(ctx, mantissa_difference > 0, + errors::Unimplemented( + "Truncate attribute is only implemented in cases where " + "dst datatype " + "has fewer mantissa bits than the src datatype")); + int src_bitwidth = xla::primitive_util::BitWidth(src_type_); + + // Bitcast to same-width integer, mask off the LSBs, bitcast back to the + // source datatype. 
+ int64 mask = ~((1LL << mantissa_difference) - 1); + xla::PrimitiveType same_width_int = + xla::primitive_util::UnsignedIntegralTypeForBitWidth(src_bitwidth); + OP_REQUIRES(ctx, same_width_int != xla::PRIMITIVE_TYPE_INVALID, + errors::Unimplemented("Unexpected type bitwidth")); + input = xla::BitcastConvertType( + xla::And( + xla::BitcastConvertType(input, same_width_int), + ::tensorflow::IntegerLiteral(builder, same_width_int, mask)), + src_type_); + } output = xla::ConvertElementType(input, dst_type_); } @@ -57,6 +90,7 @@ class CastOp : public XlaOpKernel { protected: DataType src_dtype_, dst_dtype_; xla::PrimitiveType src_type_, dst_type_; + bool use_truncation_; TF_DISALLOW_COPY_AND_ASSIGN(CastOp); }; @@ -79,8 +113,8 @@ class BitcastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; } else { - // The only complex type in XLA is C64, so error out if the bitcast has a - // complex source or destination type and the bitcast is not trivial. + // Error out if the bitcast has a complex source or destination type and + // the bitcast is not trivial. OP_REQUIRES(ctx, !xla::primitive_util::IsComplexType(src_type_) && !xla::primitive_util::IsComplexType(dst_type_), diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 7199b9b6feb36dd45ef51f4c38463bc715fcc38a..a99c6ee4431852166eec0a71bb7ad74fd5c135d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -99,8 +100,8 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType xla_output_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_type(0), &xla_output_type)); - xla::XlaOp argmax = XlaHelpers::ArgMax(softmax_entries, xla_output_type, - /*axis=*/class_dimension); + xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type, + /*axis=*/class_dimension); if (num_samples == 1) { argmax = xla::Reshape(argmax, {batch_size, 1}); } @@ -112,9 +113,12 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType type, XlaOpKernelContext* ctx) { xla::XlaBuilder* builder = ctx->builder(); - auto uniforms = - xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + // We want a number in (0, 1) rather than [0, 1) or (0, 1]: + // * log(-log(0)) is ∞. + // * log(-log(1)) is -∞. + auto uniforms = xla::RngUniform( + xla::MinPositiveNormalValue(builder, type), + xla::One(builder, uniform_shape.element_type()), uniform_shape); return xla::Log(-xla::Log(uniforms)); } @@ -143,9 +147,13 @@ class StatelessCategoricalOp : public CategoricalOp { if (uniform_shape.element_type() == xla::BF16) { uniform_shape.set_element_type(xla::F32); } + // We want a number in (0, 1) rather than [0, 1) or (0, 1]: + // * log(-log(0)) is ∞. + // * log(-log(1)) is -∞. 
auto uniforms = xla::StatelessRngUniform( - {seed0, seed1}, uniform_shape, XlaHelpers::Zero(builder, DT_FLOAT), - XlaHelpers::One(builder, DT_FLOAT)); + {seed0, seed1}, uniform_shape, + xla::MinPositiveNormalValue(builder, uniform_shape.element_type()), + xla::One(builder, uniform_shape.element_type())); return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 9fcbc86adc0967cbb7fb73da8bdabc58b60953da..0ed3044efa5b1060d2b0ad2d5563b0e02ebf66ec 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" namespace tensorflow { namespace { @@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, Cholesky(ctx->Input(0))); + ctx->SetOutput(0, xla::Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index cd7c7f4a82df7a65829787efcb1fd2f77870e945..91e4d9cea7cbf6075e30250587044174c4b8e7f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -24,13 +24,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index dff8af800229b9605bb93e0498bc5e5cf012f244..ff6c54e47c62f0555ef045e25051f6ec5a3c1d39 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -83,6 +83,17 @@ class ConstOp : public XlaOpKernel { return; } break; + case DT_COMPLEX128: + if (proto_.dcomplex_val_size() == 2) { + ctx->SetOutput( + 0, + xla::Broadcast(xla::ConstantR0( + b, xla::complex128(proto_.dcomplex_val(0), + proto_.dcomplex_val(1))), + shape.dim_sizes())); + return; + } + break; case DT_INT32: if (proto_.int_val_size() == 1) { ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 641fefafb357f6ad10483c454600f3dadd4f8cb7..e8b270c67a23b876612ab1dba92a8ae7a46a392d 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -26,13 +26,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" @@ -203,7 +203,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, absl::Span dilations, const std::vector& strides, - Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims, + absl::Span explicit_paddings) { TensorShape input_tensor_shape, filter_tensor_shape, out_backprop_tensor_shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); @@ -212,8 +213,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); return ConvBackpropComputeDimensionsV2( label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, - out_backprop_tensor_shape, dilations, strides, padding, data_format, - dims); + out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings, + data_format, dims); } } // anonymous namespace @@ -227,6 +228,10 @@ xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); 
TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + if (attrs.padding == EXPLICIT) { + TF_RETURN_IF_ERROR( + ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings)); + } string data_format; TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); @@ -298,6 +303,11 @@ xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, window_strides[i] = attrs.strides.at(dim); rhs_dilation[i] = attrs.dilations.at(dim); + if (attrs.padding == EXPLICIT) { + padding[i] = {attrs.explicit_paddings.at(dim * 2), + attrs.explicit_paddings.at(dim * 2 + 1)}; + } + int64 unused_output_size; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( input_shape.dimensions(dim), filter_shape.dimensions(i), @@ -332,7 +342,7 @@ xla::StatusOr MakeXlaBackpropInputConvOp( TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, - attrs.data_format, &dims)); + attrs.data_format, &dims, attrs.explicit_paddings)); // The input gradients are computed by a convolution of the output // gradients and the filter, with some appropriate padding. See the @@ -392,23 +402,31 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( builder->GetShape(activations)); TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, builder->GetShape(gradients)); + xla::XlaOp filter_backprop; + + xla::Shape input_shape = activations_shape; + xla::Shape output_shape = out_backprop_shape; + + TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape)); + const xla::Shape expanded_filter_shape = attrs.depthwise ? 
ExpandedFilterShapeForDepthwiseConvolution(filter_shape) : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. ConvBackpropDimensions dims; - TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( - type_string, attrs.num_spatial_dims, activations_shape, - expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, - attrs.padding, attrs.data_format, &dims)); - // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. // See the comment at the top of conv_grad_ops.h for details. - xla::ConvolutionDimensionNumbers dnums; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings)); + // The activations (inputs) form the LHS of the convolution. // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] // For the gradient computation, we flip the roles of the batch and @@ -420,6 +438,14 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + bool use_batch_group_count = + filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise; + + std::vector> padding(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector window_strides(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + // Swap n_dim and c_dim in the activations. 
dnums.set_input_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); @@ -430,28 +456,32 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( dnums.set_kernel_input_feature_dimension(n_dim); dnums.set_kernel_output_feature_dimension(c_dim); - std::vector> padding(attrs.num_spatial_dims); - std::vector rhs_dilation(attrs.num_spatial_dims); - std::vector window_strides(attrs.num_spatial_dims); - std::vector ones(attrs.num_spatial_dims, 1); + // The dimension swap below is needed because filter shape is KH,KW,F,DM. + if (use_batch_group_count) { + dnums.set_output_batch_dimension(attrs.num_spatial_dims + 1); + dnums.set_output_feature_dimension(attrs.num_spatial_dims); + } else { + dnums.set_output_batch_dimension(attrs.num_spatial_dims); + dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); + } // Tensorflow filter shape is [ H, W, ..., inC, outC ]. for (int i = 0; i < attrs.num_spatial_dims; ++i) { dnums.add_output_spatial_dimensions(i); } - dnums.set_output_batch_dimension(attrs.num_spatial_dims); - dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); - for (int i = 0; i < attrs.num_spatial_dims; ++i) { + for (int64 i = 0; i < attrs.num_spatial_dims; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = attrs.dilations[dim]; // We will also need to pad the input with zeros such that after the // convolution, we get the right size for the filter. // The padded_in_rows should be such that when we convolve this with the // expanded_out_rows as a filter, we should get filter_rows back. - // + const int64 padded_in_size = dims.spatial_dims[i].expanded_output_size + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; @@ -472,6 +502,8 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // We apply negative padding in this case. 
const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; + // + For the EXPLICIT padding, we pad the top/left side with the explicit + // padding and pad the bottom/right side with the remaining space. // + For the VALID padding, we don't pad anything on the top/left side // and pad the bottom/right side with the remaining space. // + For the SAME padding, we pad top/left side the same as bottom/right @@ -480,12 +512,12 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // In addition, if the padded input size is smaller than the input size, // we need to ignore some training elements of the input. We do this by // applying negative padding on the right/bottom. - const int64 pad_before = - attrs.padding == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - + const int64 pad_before = attrs.padding == Padding::EXPLICIT + ? attrs.explicit_paddings[2 * dim] + : attrs.padding == Padding::SAME + ? std::max(pad_total / 2, 0) + : 0; padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = attrs.dilations[dim]; } // Besides padding the input, we will also expand output_rows to @@ -496,11 +528,14 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // // This is done by specifying the window dilation factors in the // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - if (attrs.depthwise) { + filter_backprop = xla::ConvGeneralDilated( + activations, gradients, window_strides, padding, /*lhs_dilation=*/ones, + rhs_dilation, dnums, + /*feature_group_count=*/1, + /*batch_group_count=*/use_batch_group_count ? 
dims.in_depth : 1); + + if (!use_batch_group_count && attrs.depthwise) { filter_backprop = ContractFilterForDepthwiseBackprop( filter_shape, filter_backprop, activations.builder()); } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index 6e1b70a47850ae5c05939f8dfb7ec129c031df21..d893eca7f9ba07dded76eb215af4779080fa66b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -47,6 +47,7 @@ struct ConvOpAttrs { std::vector dilations; std::vector strides; Padding padding; + std::vector explicit_paddings; TensorFormat data_format; }; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index eafdba876ae9e2c38694f065cf83bb3725b8460e..52c3c2c4a903a8c51f6b511774bc0312d39df826 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -25,13 +25,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 6e6ba21daf5bf3eab5bfc15378e77b6dd253da7c..b119997cf39e210ed8e0ae730a08829e72b238b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/empty_op.cc b/tensorflow/compiler/tf2xla/kernels/empty_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..00d2ce7c12fdc96483612059d1c792c847df04f3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/empty_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Empty Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class EmptyOp : public XlaOpKernel { + public: + explicit EmptyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("init", &init_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // The output of this Op is a tensor of shape 'shape' with each + // element set to the default value of 'dtype'. If 'init' is false then + // the result values may be left undefined, though we don't do that here. 
+ const TensorShape shape_shape = ctx->InputShape("shape"); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(shape_shape), + errors::InvalidArgument("shape must be a vector of int32, got shape ", + shape_shape.DebugString())); + + std::vector shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("shape", &shape)); + + auto default_value = xla::Zero(ctx->builder(), type_); + auto result = xla::Broadcast(default_value, shape); + ctx->SetOutput(0, result); + } + + private: + DataType dtype_; + xla::PrimitiveType type_; + bool init_; +}; + +REGISTER_XLA_OP(Name("Empty").CompileTimeConstantInput("shape"), EmptyOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 6df8b5367d2390e65995beb1583b225755e6ee9f..a623585aad3b1b8f1f096ca527e7694d74f1ba46 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -21,12 +21,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 20b0de193dc060197f3062d3be0b8d45f7dcb9b1..6472045265e4d930a5da770a68f5c502192201ae 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -14,7 +14,6 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -168,13 +167,13 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); const auto params_dims = input_shape.dims(); - if (axis < 0) { - axis += params_dims; - } OP_REQUIRES( - context, 0 <= axis && axis < params_dims, + context, -params_dims <= axis && axis < params_dims, errors::InvalidArgument("Expected axis in the range [", -params_dims, ", ", params_dims, "), but got ", axis)); + if (axis < 0) { + axis += params_dims; + } } DataType index_type = input_type(1); diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 19dd38c46ef154ea74bcbb6721dd04924702efcc..8b27e8e85a37bd5aa757b0cdd7e00e9fa3c0cf6e 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -38,9 +38,13 @@ class IdentityOp : public XlaOpKernel { // XLA_* devices also register a "real" Identity operator so we suppress the // dummy operator using CompilationOnly(). 
-REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(), - IdentityOp); -REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(), +REGISTER_XLA_OP( + Name("Identity").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + IdentityOp); +REGISTER_XLA_OP(Name("IdentityN") + .AllowResourceTypes() + .AllowVariantTypes() + .CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index b5e083912555c865b5eadc7697075c9ca4451ca9..aa5637e2669555da17af8bb05ab08beeba6a89c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -56,6 +56,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; std::vector arguments(input_types_.size()); + int num_resource_args = 0; for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); @@ -79,14 +80,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.name = resource->name(); VLOG(2) << "Resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << arg.HumanString() << " initialized: " << arg.initialized; + + num_resource_args++; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString(); + << " shape: " << arg.HumanString(); } } @@ -147,12 +150,12 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape then_input_shape = then_result.xla_input_shapes[0]; - OP_REQUIRES(ctx, 
xla::ShapeUtil::IsTuple(then_input_shape), + OP_REQUIRES(ctx, then_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape else_input_shape = else_result.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape), + OP_REQUIRES(ctx, else_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), @@ -236,12 +239,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(i, output_handle); } if (has_token_input_output_) { - // Set token output for this "if" op. + // Set token output for this "If" op. Token output is the last output of + // XLA computation, which comes after all "normal" TF outputs and resource + // updates. For "If" node, num of resource updates equals the number of + // resource args because we set `return_updated_values_for_all_resources` + // to true in XlaCompiler option. xla::XlaOp token_output = - xla::GetTupleElement(outputs, output_types_.size()); + xla::GetTupleElement(outputs, output_types_.size() + num_resource_args); auto shape_or = b->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); - OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(), errors::FailedPrecondition( "Token output is not token type: ", xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index e9bb0a77e99d144863b027bd214081316d61c314..92b20fe0ba5611ca5314cd954026f7b71ea75f84 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -185,19 +187,20 @@ class AdjustContrastOpV2 : public XlaOpKernel { factor_shape.DebugString())); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp factor = context->Input(1); - DataType type = context->input_type(0); + xla::XlaOp input = context->Input(0); + xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); - auto output = XlaHelpers::ConvertElementType(reduce, type); - output = - xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + + auto output = xla::Div( + reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width)); + output = XlaHelpers::ConvertElementType(output, type); std::vector broadcast_dims(input_shape.dims() - 2); std::iota(broadcast_dims.begin(), 
broadcast_dims.end(), 0); @@ -233,8 +236,10 @@ class AdjustSaturationOp : public XlaOpKernel { channels, " channels.")); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp scale = context->Input(1); + xla::XlaOp input = + XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT); + xla::XlaOp scale = + XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT); DataType type = context->input_type(0); @@ -249,15 +254,17 @@ class AdjustSaturationOp : public XlaOpKernel { /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); - auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), - channel_shape); + auto hsv = + RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape); - hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale), - XlaHelpers::One(b, type)); + hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale), + XlaHelpers::One(b, DT_FLOAT)); - auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT); - context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); + auto output = XlaHelpers::ConvertElementType( + xla::ConcatInDim(b, rgb, channel_dim), type); + context->SetOutput(0, output); } }; REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); @@ -283,8 +290,10 @@ class AdjustHueOp : public XlaOpKernel { channels, " channels.")); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp delta = context->Input(1); + xla::XlaOp input = + XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT); + xla::XlaOp delta = + XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT); DataType type = context->input_type(0); @@ -299,20 +308,22 @@ class AdjustHueOp : public XlaOpKernel { /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); - auto 
hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), - channel_shape); + auto hsv = + RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape); - auto zero = XlaHelpers::Zero(b, type); - auto one = XlaHelpers::One(b, type); + auto zero = XlaHelpers::Zero(b, DT_FLOAT); + auto one = XlaHelpers::One(b, DT_FLOAT); auto& hue = hsv[0]; hue = xla::Rem(xla::Add(hsv[0], delta), one); hue = xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue); - auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT); - context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); + auto output = XlaHelpers::ConvertElementType( + xla::ConcatInDim(b, rgb, channel_dim), type); + context->SetOutput(0, output); } }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); @@ -351,24 +362,26 @@ struct SuppressBodyFn { auto num_outputs_so_far = values[1]; auto iou_mask = values[2]; auto included_iou = values[3]; - auto zero_r1 = xla::ConstantR1(builder, {0}); + auto zero = xla::ConstantR0(builder, 0); // Determine if current elem is active using a slice. - auto row_idx_r1 = xla::Reshape(row_idx, {1}); - auto active_elem = xla::DynamicSlice(included_iou, row_idx_r1, {1}); + // TODO(b/118437727): The only reason we need an explicit vector is because + // some old GCCs can't deduce the right type for MakeConstSpan, and + // providing a single-value initializer list directly uses the wrong + // overload. Delete this once the deprecated overload is gone. + std::vector row_idx_vector = {row_idx}; + auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1}); active_elem = xla::Reshape(active_elem, {}); // Increment output count iff current elem is not suppressed. num_outputs_so_far = xla::Select( active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), num_outputs_so_far); // Slice out the row_idx. 
- auto starts = xla::ConcatInDim(builder, {row_idx_r1, zero_r1}, 0); - auto row_iou = xla::DynamicSlice(iou_mask, starts, {1, num_boxes}); + auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); // Remove the diagonal from consideration. An elem cannot suppress // itself. - auto update_starts = xla::ConcatInDim(builder, {zero_r1, row_idx_r1}, 0); row_iou = xla::DynamicUpdateSlice( row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), - update_starts); + {zero, row_idx}); // Create a suppression by inverting polarity. row_iou = xla::Reshape(row_iou, {num_boxes}); auto supp_mask = xla::Not(row_iou); @@ -505,9 +518,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { init_values.push_back(included_iou); auto suppress_loop_result = - XlaWhileLoop(WhileCondFn(num_boxes, output_size), - SuppressBodyFn(num_boxes), init_values, "suppress_loop", - builder) + xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, + "suppress_loop", builder) .ValueOrDie(); xla::XlaOp included_score = diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 5a10c52ba8b6d4fab73f0dda67cbd52fd625e76b..d19d48e5dd95962fe4a4e4026eaf6b06b7898564 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -72,10 +73,10 @@ namespace { // from in_size to out_size. struct ResizeConvolutionDims { // Size of the kernel to use. 
- std::vector kernel_size; + std::vector kernel_size; // k // Stride of the convolution to use. - std::vector stride; + std::vector stride; // S }; ResizeConvolutionDims ComputeResizeConvolutionParameters( absl::Span in_size, absl::Span out_size, @@ -117,8 +118,10 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( // + dims.stride * (out_size - 1) int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, int64 stride) { - return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) - - 1 - (kernel_size * (in_size - 1)); + int64 padding = (2 * kernel_size - 1) + (out_size - 1) * stride - + (kernel_size - 1) - 1 - (kernel_size * (in_size - 1)); + + return padding; } // Form a 2D convolution kernel like: @@ -132,53 +135,100 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } - return xla::ConstantR1(builder, kernel); + return xla::ConvertElementType(xla::ConstantR1(builder, kernel), type); +} + +// Unlike the bilinear kernel, which is triangular, the nearest neighbor +// kernel is a square. For example, a 1D kernel with n=3 would look like +// [0 1 1 1 0] +// and n=4 would look like +// [0 0 1 1 1 1 0]. +// Note that in the second case, the kernel is not symmetric and we default +// to the right (because an existing non TPU kernel +// for nearest neighbor resize already chose to default to the right, +// so we want to be consistent). 
+xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, int64 n) { + std::vector kernel(n * 2 - 1, 0.0f); + std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f); + + return xla::ConvertElementType(xla::ConstantR1(builder, kernel), type); } // Kernels with more than 16 spatial elements are considered intense and the -// kernel should applied to each dimension independently. +// kernel should be applied to each dimension independently. const int64 kMax2DKernelSize = 16; -xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, - absl::Span kernel_size, - int64 channels) { - auto depthwise_kernel = xla::Broadcast( - xla::Zero(builder, xla::F32), - {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); - - return xla::Mul( - xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]), - /*broadcast_dimensions=*/{1}), - Make1DKernel(builder, kernel_size[0]), - /*broadcast_dimensions=*/{0}); -} +xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, + absl::Span kernel_size, + int64 channels, bool is_kernel_bilinear) { + auto make_kernel_func = + is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; -xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, - absl::Span kernel_size, - int64 channels, int64 dim) { + std::vector depthwise_kernel_sizes = { + (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}; auto depthwise_kernel = - xla::Broadcast(xla::Zero(builder, xla::F32), - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? 
(2 * kernel_size[1] - 1) : 1, channels, 1}); - return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]), - /*broadcast_dimensions=*/{dim}); + xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]), + depthwise_kernel_sizes, /*broadcast_dimensions=*/{1}); + + return xla::Mul(depthwise_kernel, + make_kernel_func(builder, type, kernel_size[0]), + /*broadcast_dimensions=*/{0}); +} + +xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder, + xla::PrimitiveType type, + absl::Span kernel_size, + int64 channels, int64 dim, + bool is_kernel_bilinear) { + auto make_kernel_func = + is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; + + std::vector depthwise_kernel_sizes = { + dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}; + return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]), + depthwise_kernel_sizes, + /*broadcast_dimensions=*/{dim}); +} + +xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder, + const xla::XlaOp& input, + int32 spatial_dimensions_offset, + absl::Span in_size, + absl::Span out_size) { + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. 
+ auto broadcast_shape_or_status = builder->GetShape(input); + if (!broadcast_shape_or_status.ok()) { + return builder->ReportError(broadcast_shape_or_status.status()); + } + xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie(); + for (int32 i = 0; i < in_size.size(); ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + broadcast_shape.set_dimensions(spatial_dimensions_offset + i, + out_size[i]); + } + } + return xla::BroadcastInDim(input, broadcast_shape.dimensions(), + /*broadcast_dimensions=*/{0, 1, 2, 3}); } -xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, - const xla::XlaOp& input, - const int num_spatial_dims, - std::vector in_size, - std::vector out_size, - const int64 channels, - const bool align_corners) { - // Picture for a 1x3 to 1x4 resize: +xla::XlaOp ResizeUsingDilationAndConvolution( + xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type, + const int num_spatial_dims, absl::Span in_size, + absl::Span out_size, const int64 channels, + const bool align_corners, bool is_kernel_bilinear) { + // Picture for a 1x3 to 1x4 bilinear resize: // stride = 2, kernel size = 3 // Input: // 3 6 9 @@ -264,8 +314,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, // Split convolutions into independent dimensions if they would be a very // large kernel. 
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size, + channels, is_kernel_bilinear); output = xla::ConvGeneralDilated(input_data, kernel, dims.stride, /*padding=*/ @@ -275,8 +325,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); } else { - xla::XlaOp kernel0 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( + builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear); output = xla::ConvGeneralDilated( input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ @@ -284,8 +334,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*lhs_dilation=*/{dims.kernel_size[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); - xla::XlaOp kernel1 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( + builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear); output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ @@ -297,22 +347,15 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. 
- for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && out_size[i] > 1) { - output = xla::Add(output, xla::ConstantR1(builder, out_size[i], 0), - /*broadcast_dimensions=*/{1 + i}); - } - } - return output; + return BroadcastSpatialDimensions( + builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size); } -xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, - const xla::XlaOp& grad, - const int num_spatial_dims, - std::vector in_size, - std::vector grad_size, - const int64 channels, - const bool align_corners) { +xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( + xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type, + const int num_spatial_dims, absl::Span in_size, + absl::Span grad_size, const int64 channels, + const bool align_corners, bool is_kernel_bilinear) { ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); @@ -332,19 +375,14 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size, + channels, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a size == 1 // dimension to a size > 1 dimension. This has the effect of summing the // gradient contributions in that dimension. 
- for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = - xla::Add(kernel, xla::ConstantR1(builder, grad_size[i], 0), - /*broadcast_dimensions=*/{i}); - } - } + kernel = BroadcastSpatialDimensions( + builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size); output = xla::ConvGeneralDilated( grad, kernel, /*window_strides=*/dims.kernel_size, @@ -355,23 +393,23 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); } else { - xla::XlaOp kernel0 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); - xla::XlaOp kernel1 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); - - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. + xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( + builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear); + xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( + builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear); + + // Broadcast the input kernel where the forward op expanded from a + // size == 1 dimension to a size > 1 dimension. This has the effect of + // summing the gradient contributions in that dimension. 
if (in_size[0] == 1 && grad_size[0] > 1) { - kernel0 = - xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), - /*broadcast_dimensions=*/{0}); + kernel0 = BroadcastSpatialDimensions(builder, kernel0, + /*spatial_dimensions_offset=*/0, {1}, + {grad_size[0]}); } if (in_size[1] == 1 && grad_size[1] > 1) { - kernel1 = - xla::Add(kernel0, xla::ConstantR1(builder, grad_size[1], 0), - /*broadcast_dimensions=*/{1}); + kernel1 = BroadcastSpatialDimensions(builder, kernel0, + /*spatial_dimensions_offset=*/0, + in_size, grad_size); } output = xla::ConvGeneralDilated( @@ -402,114 +440,148 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } } if (pad_output) { - output = xla::Pad(output, xla::ConstantR0(builder, 0.0f), padding); + output = xla::Pad(output, xla::Zero(builder, type), padding); } return output; } -class ResizeBilinearOp : public XlaOpKernel { - public: - explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); - } - - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - - TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input_shape.DebugString())); - const int64 batch = input_shape.dim_size(0); - std::vector in_size = {input_shape.dim_size(1), - input_shape.dim_size(2)}; - const int64 channels = input_shape.dim_size(3); - OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, - errors::InvalidArgument("input size must be positive, got [", - in_size[0], ",", in_size[1], "]")); - - std::vector out_size; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); - OP_REQUIRES(ctx, out_size.size() == 2, - errors::InvalidArgument("output size must be length 2, got ", - out_size.size())); - OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, - errors::InvalidArgument("output size must be positive, got [", 
- out_size[0], ",", out_size[1], "]")); - - const int num_spatial_dims = 2; - - xla::XlaOp input = ctx->Input(0); - - // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in - // dimension i. - bool slice_input = false; - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] > 1 && out_size[i] == 1) { - // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first - // entry before resizing. - slice_input = true; - in_size[i] = 1; - } - } - if (slice_input) { - input = - xla::Slice(input, {0, 0, 0, 0}, - {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); +void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, + bool is_kernel_bilinear) { + xla::XlaBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + // First dimension always assumed to be batch + const int64 batch = input_shape.dim_size(0); + std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + // Last/4th dimension always assumed to be num channels + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + std::vector out_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); + OP_REQUIRES(ctx, out_size.size() == 2, + errors::InvalidArgument("output size must be length 2, got ", + out_size.size())); + OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, + errors::InvalidArgument("output size must be positive, got [", + out_size[0], ",", out_size[1], "]")); + + const int num_spatial_dims = 2; + + xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); + + // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in + // dimension i. 
+ bool slice_input = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + slice_input = true; + in_size[i] = 1; } + } + if (slice_input) { + input = xla::Slice(input, {0, 0, 0, 0}, + {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); + } - // Output is always type float. + // Output is always type float if 'is_kernel_bilinear' is true. + if (is_kernel_bilinear) { input = xla::ConvertElementType(input, xla::F32); + input_type = xla::F32; + } - // Special Case: - // Instead of doing a ResizeUsingDilationAndConvolution directly, - // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the - // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). - // Instead of resizing directly we resize it iteratively. - // - // Since bilinear resize can be broken down as 2 sequential linear - // operations along different dimensions. - // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. - // This does not work in the case of align_corners_=false because of special - // padding requirements that cause multiple resizes to be very different - // from a single resize. - // - // This makes the convolutions kernels smaller and the operation faster. 
- xla::XlaOp output = input; - while (in_size != out_size) { - if (in_size[0] != 1 && in_size[1] != 1) { - std::vector k = { - (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), - (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; - if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && - k[0] > 1 && k[1] > 1 && align_corners_) { - std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, - (in_size[1] - 1) * 2 + 1}; - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, next_out_size, - channels, align_corners_); - input = output; - in_size = next_out_size; - } else { - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, - channels, align_corners_); - in_size = out_size; - } + // Special Case: + // Instead of doing a ResizeUsingDilationAndConvolution directly, + // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the + // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). + // Instead of resizing directly we resize it iteratively. + // + // Since bilinear resize can be broken down as 2 sequential linear + // operations along different dimensions. + // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. + // This does not work in the case of align_corners_=false because of special + // padding requirements that cause multiple resizes to be very different + // from a single resize. + // + // This makes the convolutions kernels smaller and the operation faster. 
+ xla::XlaOp output = input; + while (in_size != out_size) { + if (in_size[0] != 1 && in_size[1] != 1) { + std::vector k = { + (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), + (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; + if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && + k[0] > 1 && k[1] > 1 && align_corners_) { + std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, + (in_size[1] - 1) * 2 + 1}; + output = ResizeUsingDilationAndConvolution( + b, input, input_type, num_spatial_dims, in_size, next_out_size, + channels, align_corners_, is_kernel_bilinear); + input = output; + in_size = next_out_size; } else { - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, channels, - align_corners_); + output = ResizeUsingDilationAndConvolution( + b, input, input_type, num_spatial_dims, in_size, out_size, channels, + align_corners_, is_kernel_bilinear); in_size = out_size; } + } else { + output = ResizeUsingDilationAndConvolution( + b, input, input_type, num_spatial_dims, in_size, out_size, channels, + align_corners_, is_kernel_bilinear); + in_size = out_size; } + } - ctx->SetOutput(0, output); + ctx->SetOutput(0, output); +} + +class ResizeNearestNeighborOp : public XlaOpKernel { + public: + explicit ResizeNearestNeighborOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented("ResizeNearestNeighbor with align_corners=False " + "is not yet implemented")); + } + + void Compile(XlaOpKernelContext* ctx) override { + GeneralCompile(ctx, align_corners_, is_kernel_bilinear_); } private: - bool align_corners_; + bool align_corners_ = true; + bool is_kernel_bilinear_ = false; +}; + +REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"), + ResizeNearestNeighborOp); + +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit 
ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + GeneralCompile(ctx, align_corners_, is_kernel_bilinear_); + } + + private: + bool align_corners_ = true; + bool is_kernel_bilinear_ = true; }; REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"), @@ -580,20 +652,20 @@ class ResizeBilinearGradOp : public XlaOpKernel { std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, next_grad_size, channels, - align_corners_); + b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size, + channels, align_corners_, true); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels, - align_corners_); + b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels, + align_corners_, true); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels, - align_corners_); + b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels, + align_corners_, true); in_size = grad_size; } } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 843b6bb4e658af16fd753c1a20b35dd3d18df027..c1539f48d4f729510b2d930de91666a7c31f1ef0 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -18,17 +18,16 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/index_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min) @@ -66,9 +65,9 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp input = ctx->Input(0); xla::XlaOp output; if (is_min_) { - output = XlaHelpers::ArgMin(input, index_xla_type, axis); + output = xla::ArgMin(input, index_xla_type, axis); } else { - output = XlaHelpers::ArgMax(input, index_xla_type, axis); + output = xla::ArgMax(input, index_xla_type, axis); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index e2c05b648bb194b1b452c527ddb1a2c5995b1217..e4bbdef6480104a1051acfc647644deb65c80171 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -16,16 +16,16 @@ limitations under the License. // Native XLA implementations of indexing ops. 
#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -74,7 +74,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // shape isn't supported. if (!ctx->compiler()->options().allow_cpu_custom_calls || (input_dims != 1 && input_dims != 2)) { - xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis); + xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis); ctx->SetOutput(0, output); return; } @@ -110,8 +110,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel { auto shape_status = b.GetShape(arg); OP_REQUIRES_OK(ctx, shape_status.status()); xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); - *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout( - xla::ShapeUtil::Rank(arg_shape)); + *arg_shape.mutable_layout() = + xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank()); arg_shapes.push_back(std::move(arg_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 47cf8c6675bc120653c2a5ab6d4b07376dc382ee..39d96e748b3a2a852c03c0dd53ec175f0c66a43a 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -25,9 +25,6 @@ limitations under the License. 
namespace tensorflow { EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*)); - float* input = static_cast(data[0]); int64 input_size = *static_cast(data[1]); diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 6440770c29894c951f010f6c1deb929f4fe79bbf..f36e0025250b3a196b31755a1ddf6620c415b6a3 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -24,8 +24,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}}; class MatMulOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 90c0ebefb24ec2c4378782e9b15d3f57c33032a4..5a6569c8954d1686dc9d7577a66feb720241ea13 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -15,7 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { namespace { @@ -31,7 +32,10 @@ class MatrixTriangularSolveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { auto result = xla::TriangularSolve( ctx->Input(0), ctx->Input(1), /*left_side=*/true, - /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); + /*lower=*/lower_, /*unit_diagonal=*/false, + /*transpose_a=*/ + adjoint_ ? xla::TriangularSolveOptions::ADJOINT + : xla::TriangularSolveOptions::NO_TRANSPOSE); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index f6b8534f4d7c537e5b708ee000e00cb92123584b..656f9b898f32dfc05215014f51c2bbaf07580836 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -38,8 +38,7 @@ class MirrorPadOp : public XlaOpKernel { // - [1, 2, 3, 3, 2] in symmetric mode. int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0; xla::XlaOp accum = t; - for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; - --dimno) { + for (int64 dimno = original_shape.rank() - 1; dimno >= 0; --dimno) { auto t_rev = xla::Rev(accum, {dimno}); int64 lhs_padding = pad_literal.Get({dimno, 0}); int64 rhs_padding = pad_literal.Get({dimno, 1}); diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index a9b519d8928cc2807831fd6b4f12e60b7d58ea55..426a0941df57f19072d1cb9f3fa3d0079db465c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -24,12 +24,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 06c6cc37ec90192486ba15010bfeb763a9ffb987..23bb050a34d9246cdf73090aa6adfca054bf8bcf 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -26,10 +26,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/pooling_ops_common.h" diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index 7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52..66ec40a946b8a063d84acd33daf81f52ea2c35ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" namespace tensorflow { namespace { @@ -26,7 +26,7 @@ class QROp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); } void Compile(XlaOpKernelContext* ctx) override { - auto result = QRDecomposition(ctx->Input(0), full_matrices_); + auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_); if (!result.ok()) { ctx->SetStatus(result.status()); return; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e..d6c70d4af1c2e921b70b0869f0163c8481017c7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -20,12 +20,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -160,23 +161,30 @@ class RandomShuffleOp : public XlaOpKernel { -> xla::StatusOr> { auto swaps = loop_vars[0]; auto indices = loop_vars[1]; - i = xla::Reshape(i, {1}); + // TODO(b/118437727): The absl::Span nonsense is only necessary because + // the deprecated overload creates ambiguity for the single-element span + // case. Remove it once the deprecated overload is gone. 
// temp = indices[i] - auto temp = xla::DynamicSlice(indices, i, {1}); + auto temp = + xla::DynamicSlice(indices, absl::Span({i}), {1}); // swap_index = swaps[i] - auto swap_index = xla::DynamicSlice(swaps, i, {1}); + auto swap_index = xla::Reshape( + xla::DynamicSlice(swaps, absl::Span({i}), {1}), {}); // swap_value = indices[swaps[i]] - auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); + auto swap_value = xla::DynamicSlice( + indices, absl::Span({swap_index}), {1}); // indices[i] = indices[swaps[i]] - indices = xla::DynamicUpdateSlice(indices, swap_value, i); + indices = xla::DynamicUpdateSlice(indices, swap_value, + absl::Span({i})); // indices[swaps[i]] = temp - indices = xla::DynamicUpdateSlice(indices, temp, swap_index); + indices = xla::DynamicUpdateSlice( + indices, temp, absl::Span({swap_index})); return std::vector{swaps, indices}; }; // for i in range(n): auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) + xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) .ValueOrDie(); auto swapped_indices = swap_loop_result[1]; @@ -272,9 +280,9 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + xla::XlaOp one = xla::One(b, xla_shape.element_type()); xla::XlaOp min_positive = - XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits::min()); + xla::MinPositiveNormalValue(b, xla_shape.element_type()); auto uniform = xla::RngUniform(min_positive, one, xla_shape); ctx->SetOutput(0, TruncatedNormal(uniform)); } diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 54d34a38abc4948a1a08197d72e3e7f763649093..f9985d526033ca675c701a508a3d1576e46bc5f7 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -125,7 +125,7 
@@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, dimensions.back() = 1; auto batch_indices = - xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions), + xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions), /*iota_dimension=*/0); return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); @@ -189,11 +189,53 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, scatter_dim_numbers); } +// Bounds samples to 0 if the warp image indices are out of the (-1, image_size) +// bound. +// The resulting dimension is given by 'result_dims'. +XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + std::vector result_dims, + std::vector broadcasted_dims, int64 last_warp_dim, + xla::Shape data_shape, XlaOp sample) { + auto is_gt_minus_one = + xla::Gt(warp, + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, -1}), warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dimensions(2)), + /*height=*/static_cast(data_shape.dimensions(1))}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. + auto is_in_bound = xla::Reduce( + is_in_bound_padded_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'result_dims'. + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); + + // Set out of bound samples to zero. 
+ auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims); + return xla::Select(broadcasted_is_in_bound, sample, zeros); +} + // Build computation the backprop into input 'data'. // Where input: // grad_output is of dimension [batch, dim_0, ...dim_n, channel] // ratio is of dimension [batch, dim_0, ...dim_n, 2] // gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// data_shape is of dimension [batch, x(width), y(height), channel] // // Output: // scatter-add to each 2x2 grad_data neighbor: @@ -201,10 +243,12 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, // grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy // grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) // grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) -// where (dx, dy) is (1 - ratio). +// where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then the their +// contribution is 0 to 'grad_data'. XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, - XlaOp gather_indices, xla::PrimitiveType warp_type, - TensorShape warp_shape, int64 data_channels, + XlaOp gather_indices, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + int64 last_warp_dim, int64 data_channels, xla::Shape data_shape) { // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); @@ -229,6 +273,18 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), 0); + // Set out of bound weights to 0. + // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2]. 
+ std::vector reshaped_result_dims(warp_dims.begin(), + warp_dims.end() - 1); + reshaped_result_dims.push_back(2); + reshaped_result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape, + reshaped_result_dims, broadcasted_dims, + last_warp_dim, data_shape, reshaped_weights); + // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. auto broadcast_reshaped_weights = xla::BroadcastInDim( reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); @@ -245,18 +301,41 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, auto grad_data = xla::ConstantLiteral( ctx->builder(), xla::Literal::CreateFromShape(data_shape)); - return ScatterToGradData(ctx, grad_data, gather_indices, - grad_output_multiply_weights, warp_shape.dims(), - warp_type); + // Pad grad data then slice it back. + // + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. + auto padded_grad_data = + xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + auto updated_grad_data = ScatterToGradData( + ctx, padded_grad_data, shifted_gather_indices, + grad_output_multiply_weights, warp_shape.dims(), warp_type); + + const int64 batch_size = data_shape.dimensions(0); + const int64 width = data_shape.dimensions(1); + const int64 height = data_shape.dimensions(2); + // Slice out the result accounting for the padding. 
+ return xla::Slice( + updated_grad_data, /*start_indices=*/{0, 1, 1, 0}, + /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels}, + /*strides=*/{1, 1, 1, 1}); } // Build computation for the backprop into input 'warp'. // Where input: -// warp is of dimension [batch, dim_0, ...dim_n, 2] -// grad_output is of dimension [batch, dim_0, ...dim_n, channel] -// ratio is of dimension [batch, dim_0, ...dim_n, 2] -// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] -// data is of dimension [batch, x, y, channel] +// warp is of dimension [batch, dim_0, ...dim_n, 2] +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last +// dimension of size 3 is for {batch, x(width), y(height)}. +// data is of dimension [batch, x, y, channel] // // Output (simplified by ignoring the batch dimensions): // Since the forward path has: @@ -275,12 +354,12 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, // grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) // grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) // -// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the +// where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the // bottom right corner in a 2x2 neighborhood. 
XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, XlaOp gather_indices, XlaOp data, TensorShape warp_shape, int64 data_channels, - xla::PrimitiveType data_type) { + xla::PrimitiveType data_type, xla::Shape data_shape) { auto warp_dims = warp_shape.dim_sizes(); std::vector warp_dims_without_last_dims(warp_dims.begin(), warp_dims.end() - 1); @@ -289,12 +368,30 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; neighbor_broadcast_dims.push_back(4); - // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] - auto neighbors_data = Gather2by2Neighbors( - ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + // With dimension [batch, dim_0, ...dim_n, 4] + auto neighbor_broadcast_shape = + xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); const int64 last_warp_dim = warp_shape.dims() - 1; + // Pad data with 0, before gathering such that 0 will be returned for samples + // in the range of (-1, 0) or (image_dimension-1, image_dimension). + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. + auto padded_data = + xla::Pad(data, xla::Zero(ctx->builder(), data_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + // The dimension is [batch, dim_0, ... 
dim_n, 4, data_channels] + auto neighbors_data = + Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices, + data_channels, warp_shape.dims()); + // Since we will be creating the dot product of: // lhs: [batch, dim_0, ...dim_n, 4] // and @@ -417,7 +514,7 @@ class ResamplerOp : public XlaOpKernel { // Find the coordinates of the top left corner for the 2x2 region to be // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the // last dimension of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(warp, xla::S32); auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); @@ -526,7 +623,8 @@ class ResamplerGradOp : public XlaOpKernel { size, "]")); } // Last dimension of warp shape must be of size 2. - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + const int64 last_warp_dim = warp_shape.dims() - 1; + OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, errors::InvalidArgument( "the last dimension of warp must be exactly size 2.")); xla::PrimitiveType warp_type = ctx->input_xla_type(1); @@ -549,24 +647,32 @@ class ResamplerGradOp : public XlaOpKernel { // Find the top left corner coordinate for the region to be sampled from. // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension // of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32); - // Dimensions are [batch, dim_0, ... dim_n, 2] + // Dimensions are [batch, dim_0, ... dim_n, 2]. XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); // Indices for gathering neighboring pixels. 
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); - auto grad_data = - CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type, - warp_shape, data_channels, data_shape); + auto grad_data = CalculateGradData( + ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape, + last_warp_dim, data_channels, data_shape); auto grad_warp = CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, - warp_shape, data_channels, data_type); + warp_shape, data_channels, data_type, data_shape); + auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto grad_warp_bounded = + BoundSamples(ctx, warp, warp_type, warp_shape, result_dims, + broadcasted_dims, last_warp_dim, data_shape, grad_warp); ctx->SetOutput(0, grad_data); - ctx->SetOutput(1, grad_warp); + ctx->SetOutput(1, grad_warp_bounded); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index e4046c795577983bff1a8053743bf4d3a258e583..1f417037284c87753b219ea5ce1d4edce0ce6336 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -37,10 +37,14 @@ class RetvalOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const Tensor& input = ctx->op_kernel_context()->input(0); - OP_REQUIRES(ctx, input.dtype() == dtype_, - errors::InvalidArgument( - "Type mismatch: actual ", DataTypeString(input.dtype()), - " vs. expect ", DataTypeString(dtype_))); + // DT_VARIANT types represent Tensor Lists and are wrapped in a DT_UINT8 + // tensor so we skip the check here. + if (dtype_ != DT_VARIANT) { + OP_REQUIRES(ctx, input.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(input.dtype()), + " vs. 
expect ", DataTypeString(dtype_))); + } auto frame = ctx->call_frame(); if (frame) { // If 'frame' is non-null, this is an inner function call inside a JIT @@ -59,8 +63,9 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(), - RetvalOp); +REGISTER_XLA_OP( + Name("_Retval").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 4b9e1a578be2445091228953df7e5c5e82b42c28..daefdfc58a4957d9e685d25aa90da6218f2041ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -23,13 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index a95e7adacf194ba6eb33cbeb56abe1a5a2479337..a1c18bed3f94008af8038f32324c79aa5b2abded 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -110,10 +110,16 @@ class ScatterNdOp : public XlaOpKernel { auto updates = context->Input(1); 
auto result = XlaScatter(buffer, updates, indices, - /*indices_are_vectors=*/true, /*combiner=*/{}, builder); + /*indices_are_vectors=*/true, /*combiner=*/Combine, builder); OP_REQUIRES_OK(context, result.status()); context->SetOutput(0, result.ValueOrDie()); } + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); + } }; REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"), diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 9e4c57c9bf73369662274f6b783418e18ff860c2..aaf8c6075dd292e33e70683774a6c1bf374183e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index b1fa2915d59e4e5e2f2523e20e9a37898d087117..7a620d2a6518f8686ef570b33aac971d1dccb6c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -157,9 +157,11 @@ class LinSpaceOp : public XlaOpKernel { flat(0) = start; } else { const float step = (stop - start) / (num - 1); - for (int64 i = 0; i < num; ++i) { + for (int64 i = 0; i < num - 1; ++i) { flat(i) = start + step * i; } + // The last value in the sequence must be equal to stop. 
+ flat(num - 1) = stop; } break; } @@ -171,9 +173,11 @@ class LinSpaceOp : public XlaOpKernel { flat(0) = start; } else { const double step = (stop - start) / (num - 1); - for (int64 i = 0; i < num; ++i) { + for (int64 i = 0; i < num - 1; ++i) { flat(i) = start + step * i; } + // The last value in the sequence must be equal to stop. + flat(num - 1) = stop; } break; } diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 12830816ec16c9797f0fe4d8f3f13f5a8176161d..280b68383c28d1b9d88f7b2ac0f8fab47244c05d 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,10 +20,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -91,14 +92,20 @@ class SizeOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int64 size = input_shape.num_elements(); - OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(input_shape.num_elements(), + std::numeric_limits::max()), errors::InvalidArgument("Size does not work for tensors > " "int32 max.")); Tensor size_constant(DT_INT32, TensorShape({})); - size_constant.scalar()() = static_cast(size); - - ctx->SetConstantOutput(0, size_constant); + const int rank = input_shape.dims(); + xla::XlaBuilder* builder = ctx->builder(); + auto size = xla::One(builder, xla::U32); + for (int64 i = 0; i < 
rank; ++i) { + size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i)); + } + size = xla::ConvertElementType(size, ctx->output_xla_type(0)); + ctx->SetOutput(0, size); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index 76ea5f525598f511f295eb5a30f3cf603fbf57aa..b18e3f965c427aec456ce2b188dad79485df23cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/framework/bounds_check.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 622efac81766fc3ddaf538b58170f34fce06927a..52bed2670b4b8408e3b2f72b64bf370aea5325f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -39,7 +39,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, OP_REQUIRES( ctx, - xla::ShapeUtil::Rank(paddings.shape()) == 2 && + paddings.shape().rank() == 2 && block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1), errors::InvalidArgument("paddings should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 8e9e4daf99d3dd3b8e149e3f3e5f6c27665c0fcb..b6c96b1f582710e1cc39e6e1e0e800ef8170743d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -24,13 +24,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -45,7 +45,7 @@ Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, return shape_or_status.status(); } xla::Shape shape = shape_or_status.ValueOrDie(); - TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + TF_RET_CHECK(shape.IsTuple()); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), stack_shape); } @@ -146,9 +146,9 @@ class StackPushOp : public XlaOpKernel { xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -202,9 +202,9 @@ class StackPopOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
- auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); + std::vector start_indices(stack_shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = stack_shape.dim_sizes(); slice_shape[0] = 1LL; diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 50653d7b3973b73d580cdeec5d71943b575d7cc9..17f067e0dfcf4f8b360ee6db934df3e373d5fdd1 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -218,8 +218,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); auto uniform = xla::StatelessRngUniform( {seed0, seed1}, xla_shape, - xla::ConstantR0(builder, std::numeric_limits::min()), - xla::ConstantR0(builder, 1.0)); + xla::MinPositiveNormalValue(builder, xla_shape.element_type()), + xla::One(builder, xla_shape.element_type())); auto output = TruncatedNormal(uniform); output = MaybeConvertF32ToBF16(output, dtype_); ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 10d990b3213ab882cf44a4df20a977633de3fdab..e8846fbe88fa2a75244398ef0f601fd74e80ec50 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -288,19 +288,21 @@ class StridedSliceAssignOp : public XlaOpKernel { xla::XlaOp rhs = ctx->Input(4); absl::InlinedVector dimensions_to_reverse; - absl::InlinedVector slice_begin, slice_dims; + absl::InlinedVector slice_begin; + absl::InlinedVector slice_dims; for (int i = 0; i < begin.size(); ++i) { - // TODO(phawkins): implement strides != 1 + // TODO(b/121179231): implement strides != 1 OP_REQUIRES( ctx, strides[i] == 1 || strides[i] == 
-1, errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); if (strides[i] > 0) { - slice_begin.push_back(begin[i]); + slice_begin.push_back(xla::ConstantR0(ctx->builder(), begin[i])); slice_dims.push_back(end[i] - begin[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. - slice_begin.push_back(end[i] + 1); + slice_begin.push_back( + xla::ConstantR0(ctx->builder(), end[i] + 1)); slice_dims.push_back(begin[i] - end[i]); dimensions_to_reverse.push_back(i); } @@ -311,14 +313,7 @@ class StridedSliceAssignOp : public XlaOpKernel { } rhs = xla::Reshape(rhs, slice_dims); - if (lhs_shape.dims() == 0) { - // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix - // and remove this workaround. - lhs = rhs; - } else { - lhs = xla::DynamicUpdateSlice( - lhs, rhs, xla::ConstantR1(ctx->builder(), slice_begin)); - } + lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 939d7e19515a1cb41e3e23e9d1fa957ae09ecab7..77a3e5c001e1c715f23ae5148f94dae2faa81acf 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -27,13 +27,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -123,7 +123,8 @@ Status GetTensorArrayShape(const XlaResource* resource, xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, absl::Span update_dims, - const xla::XlaOp& start_indices, DataType dtype) { + absl::Span start_indices, + DataType dtype) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); xla::XlaOp sum = dtype == DT_BOOL ? xla::Or(current, update) : xla::Add(current, update); @@ -212,9 +213,9 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::XlaOp flow = ctx->Input(3); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -263,9 +264,9 @@ class TensorArrayReadOp : public XlaOpKernel { xla::XlaOp index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
- auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); + std::vector start_indices(ta_shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; @@ -419,10 +420,10 @@ class TensorArrayScatterOp : public XlaOpKernel { auto slice = xla::Slice(value, value_starts, value_ends, value_strides); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = xla::Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + auto index = xla::Reshape(xla::Slice(indices, {i}, {i + 1}, {1}), {}); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 64a24703ae1460abfedb6d9298e1e164076a199a..8958a48bc79dce91c41ab7d0a5fc0fbb401112ba 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ // XLA TensorList operators. +// Tensor lists are represented as tuple consisting of a pre-allocated list +// consisting of the tensors (and where dim 0 is the list index), along with a +// scalar telling us the current number of elements. #include #include @@ -23,15 +26,17 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -45,11 +50,64 @@ Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, return shape_or_status.status(); } xla::Shape shape = shape_or_status.ValueOrDie(); - TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + TF_RET_CHECK(shape.IsTuple()); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), tensor_list_shape); } +class TensorListLengthOp : public XlaOpKernel { + public: + explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp tl = ctx->Input(0); + xla::XlaOp index = xla::GetTupleElement(tl, 1); + ctx->SetOutput(0, index); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp); +}; + +REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp); + +// Creates an empty list with size (leading_dim, *element_shape) if +// element_shape is known at compile time. Otherwise creates one with size +// (leading_dim, 0) which gets initialized later in `GetInitializedList`. 
+Status CreateZerosList(XlaOpKernelContext* ctx, int element_shape_index, + int64 leading_dim, DataType dtype, xla::XlaOp* list) { + TensorShape list_shape; + list_shape.AddDim(leading_dim); + xla::XlaOp element_shape_handle = ctx->Input(element_shape_index); + TF_ASSIGN_OR_RETURN( + bool is_element_shape_compile_time_const, + element_shape_handle.builder()->IsConstant(element_shape_handle)); + PartialTensorShape partial_element_shape; + if (is_element_shape_compile_time_const) { + TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape( + element_shape_index, &partial_element_shape)); + } + if (is_element_shape_compile_time_const && + partial_element_shape.IsFullyDefined()) { + TensorShape element_shape; + partial_element_shape.AsTensorShape(&element_shape); + list_shape.AppendShape(element_shape); + } else { + // If element_shape is not a compile time constant or if it is not fully + // defined we will have to wait for the first write call to fully allocate + // the array. + // TODO(srbs): We are using element_shape of [0] as a proxy to denote an + // uninitialized list. A better implementation may be to represent the + // list as a 3-tuple containining an explicit "initialized" flag. However, + // we would still need to create a dummy tensor for the first tuple + // element. 
+ list_shape.AddDim(0); + } + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + return Status::OK(); +} + class TensorListReserveOp : public XlaOpKernel { public: explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -57,19 +115,15 @@ class TensorListReserveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); - TensorShape tensor_shape; - tensor_shape.AddDim(num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list)); xla::XlaBuilder* b = ctx->builder(); - ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, 0)})); + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {list, xla::ConstantR0(b, num_elements)})); } private: @@ -85,19 +139,37 @@ REGISTER_XLA_OP(Name("TensorListReserve") class EmptyTensorListOp : public XlaOpKernel { public: - explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { - ctx->CtxFailure( + int64 max_num_elements; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); + OP_REQUIRES( + ctx, max_num_elements >= 0, errors::InvalidArgument("XLA compilation requires a fixed tensor list " - "size. Use TensorListReserve instead.")); + "size. 
Set the max number of elements.")); + + xla::XlaOp list; + OP_REQUIRES_OK(ctx, + CreateZerosList(ctx, 0, max_num_elements, dtype_, &list)); + + xla::XlaBuilder* b = ctx->builder(); + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {list, xla::ConstantR0(b, 0)})); } private: + DataType dtype_; + TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp); }; -REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp); +REGISTER_XLA_OP(Name("EmptyTensorList") + .CompileTimeConstantInput("element_shape") + .CompileTimeConstantInput("max_num_elements"), + EmptyTensorListOp); class TensorListElementShapeOp : public XlaOpKernel { public: @@ -139,6 +211,168 @@ class TensorListElementShapeOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); +class TensorListGetItemOp : public XlaOpKernel { + public: + explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = ctx->Input(1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + std::vector start_indices(shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + // Remove the leading '1' dimension. 
+ std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp); +}; + +REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp); + +class TensorListStackOp : public XlaOpKernel { + public: + explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp state = ctx->Input(0); + xla::XlaOp ta = xla::GetTupleElement(state, 0); + ctx->SetOutput(0, ta); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp); +}; + +REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp); + +class TensorListFromTensorOp : public XlaOpKernel { + public: + explicit TensorListFromTensorOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &element_shape)); + + const TensorShape tensor_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, tensor_shape.dims() > 0, + errors::InvalidArgument("Input value must be at least a " + "vector but received shape: ", + tensor_shape.DebugString())); + const int num_elements = tensor_shape.dim_size(0); + + xla::XlaBuilder* b = ctx->builder(); + const xla::XlaOp tensor = ctx->Input(0); + + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {tensor, xla::ConstantR0(b, num_elements)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp); +}; + +REGISTER_XLA_OP( + Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"), + TensorListFromTensorOp); + +// Returns the 0'th element of `tuple` containing the list tensor if it has been +// initialized already else creates one lazily. 
This allows lazy initialization +// of the list on the first call to SetItem or PushBack. +Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple, + const TensorShape& element_shape, DataType dtype, + xla::XlaOp* list) { + *list = xla::GetTupleElement(tuple, 0); + TensorShape list_shape; + TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape)); + int64 leading_dim = list_shape.dim_size(0); + TensorShape list_element_shape = list_shape; + list_element_shape.RemoveDim(0); + // This checks for the lazy initialization contract set by CreateEmptyList. + // In TensorListReserve if the element_shape is not known at compile time, + // it creates a list with shape [leading_dim, 0]. + if (element_shape != list_element_shape) { + if (list_element_shape.num_elements() != 0) { + return errors::InvalidArgument( + "Invalid shape of value in TensorListSetItem. Expected: ", + list_element_shape.DebugString(), + " Actual: ", element_shape.DebugString()); + } + list_shape = element_shape; + list_shape.InsertDim(0, leading_dim); + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + } + return Status::OK(); +} + +class TensorListSetItemOp : public XlaOpKernel { + public: + explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp tl = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(2); + + xla::XlaOp list; + OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list)); + + xla::XlaOp index = ctx->Input(1); + xla::XlaOp value = ctx->Input(2); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
+ std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), + xla::GetTupleElement(tl, 1)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp); +}; + +REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp); + class TensorListPushBackOp : public XlaOpKernel { public: explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -147,26 +381,27 @@ class TensorListPushBackOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp list = ctx->Input(0); + xla::XlaOp list_tuple = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(1); - xla::XlaOp ta = xla::GetTupleElement(list, 0); - xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp list; + OP_REQUIRES_OK( + ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list)); + + xla::XlaOp index = xla::GetTupleElement(list_tuple, 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - // TODO(phawkins): We don't check the index is in bounds --- there is no - // error mechanism in XLA. 
- ctx->SetOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), index + xla::ConstantR0(b, 1)})); } @@ -197,20 +432,17 @@ class TensorListPopBackOp : public XlaOpKernel { index = index - xla::ConstantR0(b, 1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}})); - + std::vector start_indices(shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = shape.dim_sizes(); slice_shape[0] = 1LL; - // TODO(phawkins): We don't check the index is in bounds --- there is no - // error mechanism in XLA. xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + ctx->SetTensorListOutput(0, xla::Tuple(b, {ta, index})); ctx->SetOutput(1, xla::Reshape(read, value_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 960c1462ceb8c00a2d6c96564f6c985fd1caef0f..ceb762038009f7a3ff80d9ad4066af43d54a9e34 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -172,6 +172,65 @@ class ResourceApplyMomentum : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes), ResourceApplyMomentum); +class ResourceApplyKerasMomentum : public XlaOpKernel { + public: + explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(2); + + TensorShape var_shape, 
accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + TensorShape momentum_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), + errors::InvalidArgument("momentum is not a scalar: ", + momentum_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp momentum = ctx->Input(4); + + accum = accum * momentum - grad * lr; + if (use_nesterov_) { + // See https://github.com/tensorflow/tensorflow/pull/2798 for an + // explanation of the reparameterization used here. 
+ var = var + accum * momentum - grad * lr; + } else { + var = var + accum; + } + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); + } + + private: + bool use_nesterov_; +}; +REGISTER_XLA_OP( + Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes), + ResourceApplyKerasMomentum); + class ResourceApplyAdagrad : public XlaOpKernel { public: explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -797,15 +856,12 @@ class ResourceApplyAdadelta : public XlaOpKernel { xla::XlaOp grad = ctx->Input(6); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5); - xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - accum = rho * accum + (one - rho) * xla::Pow(grad, two); - xla::XlaOp update = xla::Pow(accum_update + epsilon, half) * - xla::Pow(accum + epsilon, neg_half) * grad; - accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two); + accum = rho * accum + (one - rho) * xla::Square(grad); + xla::XlaOp update = + xla::Sqrt(accum_update + epsilon) * xla::Rsqrt(accum + epsilon) * grad; + accum_update = rho * accum_update + (one - rho) * xla::Square(update); var = var - update * lr; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index c9b324a243e4cc3ec64daa3ca0d285336a0d0154..4ac714306248302242902f20d45d2609ef2c7cd3 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -19,14 +19,15 @@ limitations under the License. // helper. 
#include "tensorflow/core/kernels/transpose_op.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -128,29 +129,46 @@ class InvertPermutationOp : public XlaOpKernel { errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); - std::vector perm; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); - - int size = perm.size(); + auto e = ctx->InputExpression(0); + auto tensor_or_status = e.ResolveConstant(ctx->compiler()->client()); + OP_REQUIRES_OK(ctx, tensor_or_status.status()); + // If the input is a constant, we also want the output to be a constant. + // Some models rely on the result of InvertPermutation being a constant. + // TODO(b/32495713): Remove this when we can check whether Scatter is + // constant. Right now, we always assume it is non-constant because we don't + // check the embedded computation. 
+ if (tensor_or_status.ValueOrDie().has_value()) { + std::vector perm; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); + + int size = perm.size(); + + std::vector output(size); + std::fill_n(output.data(), size, -1); + for (int i = 0; i < size; ++i) { + const int64 d = perm[i]; + OP_REQUIRES(ctx, FastBoundsCheck(d, size), + errors::InvalidArgument(d, " is not between 0 and ", size)); + OP_REQUIRES(ctx, output[d] == -1, + errors::InvalidArgument(d, " is duplicated in the input.")); + output[d] = i; + } - std::vector output(size); - std::fill_n(output.data(), size, -1); - for (int i = 0; i < size; ++i) { - const int64 d = perm[i]; - OP_REQUIRES(ctx, FastBoundsCheck(d, size), - errors::InvalidArgument(d, " is not between 0 and ", size)); - OP_REQUIRES(ctx, output[d] == -1, - errors::InvalidArgument(d, " is duplicated in the input.")); - output[d] = i; + ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); + } else { + auto indices = ctx->Input(0); + int size = ctx->InputShape(0).num_elements(); + auto iota = xla::Iota(ctx->builder(), xla::S32, size); + auto result = XlaScatter(iota, iota, indices, + /*indices_are_vectors=*/false, /*combiner=*/{}, + ctx->builder()); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } - - ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); } }; -REGISTER_XLA_OP(Name("InvertPermutation") - .TypeConstraint("T", DT_INT32) - .CompileTimeConstantInput("x"), +REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), InvertPermutationOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a0ea6422d732b00fc1b8cf855d9c9ad603b87c82..62b5cd32da59063f8ce07119fd085f91ec3a1bc4 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -65,11 +65,8 @@ XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); 
XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); -XLAJIT_MAKE_UNARY( - IsInf, - xla::Eq(xla::Abs(x), - xla::ScalarLike(x, std::numeric_limits::infinity()))); -XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); +XLAJIT_MAKE_UNARY(IsInf, xla::IsInf(x)); +XLAJIT_MAKE_UNARY(IsNan, xla::IsNan(x)); // Return 1/x XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x); XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x); @@ -92,8 +89,9 @@ xla::XlaOp Sigmoid(xla::XlaOp x) { } XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); -// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); +// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +XLAJIT_MAKE_UNARY(Sign, + xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x))); XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x)); // softplus(x) = log(1 + exp(x)) @@ -116,37 +114,11 @@ XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); XLAJIT_MAKE_UNARY(Real, xla::Real(x)); XLAJIT_MAKE_UNARY(Imag, xla::Imag(x)); +XLAJIT_MAKE_UNARY(Erf, xla::Erf(x)); +XLAJIT_MAKE_UNARY(Erfc, xla::Erfc(x)); #undef XLAJIT_MAKE_UNARY -// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial -// is used outside of this range. 
-class ErfOp : public XlaOpKernel { - public: - explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp x = ctx->Input(0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); - auto y = - xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x)); - ctx->SetOutput(0, y); - } -}; -REGISTER_XLA_OP(Name("Erf"), ErfOp); - -class ErfcOp : public XlaOpKernel { - public: - explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp x = ctx->Input(0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); - auto y = - xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x)); - ctx->SetOutput(0, y); - } -}; -REGISTER_XLA_OP(Name("Erfc"), ErfcOp); - class LgammaOp : public XlaOpKernel { public: explicit LgammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 8671632976023fded04c26a9780c1a67638b0916..2fc5619de737b8977e4249e4d2297a0303c339ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -24,12 +24,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 2c92a585f5679242d672d0402e617ff199b94f17..dfa09b16081e93ba843a1858e68e6ff756de20c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -291,5 +291,19 @@ class ResourceScatterNdAddOp : public ResourceScatterOp { }; REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); +class ResourceScatterNdSubOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdSubOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Sub(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterNdSub"), ResourceScatterNdSubOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ce007fc04a818869686b9936a1607cee42665e87..f49da9683b3622bdda708cc305306baafa1639df 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -41,8 +41,7 @@ Status MakeXlaCompilerArgumentsFromInputs( 
*has_uninitialized_vars = false; *has_tensor_arrays = false; for (int i = 0; i < ctx->num_inputs(); ++i) { - VLOG(2) << " Input " << i - << " type: " << DataTypeString(ctx->input_type(i)) + VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); XlaCompiler::Argument& arg = (*args)[i]; DataType type = ctx->input_type(i); @@ -71,13 +70,20 @@ Status MakeXlaCompilerArgumentsFromInputs( arg.name = resource->name(); VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << arg.ShapeHumanString() << " initialized: " << arg.initialized; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = ctx->input_type(i); - arg.shape = ctx->InputShape(i); + + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp handle = ctx->Input(i); + auto shape_or_status = builder->GetShape(handle); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + arg.shape = shape_or_status.ValueOrDie(); } } return Status::OK(); @@ -207,12 +213,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape body_input_shape = body.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(body_input_shape), + OP_REQUIRES(ctx, body_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape cond_input_shape = cond.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(cond_input_shape), + OP_REQUIRES(ctx, cond_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) @@ -233,13 +239,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { 
xla::ShapeUtil::HumanString(body_input_shape), " vs. ", xla::ShapeUtil::HumanString(body.xla_output_shape))); - xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_without_side_effect = + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_with_side_effect = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}), + xla::ShapeUtil::MakeTokenShape()}); OP_REQUIRES(ctx, - xla::ShapeUtil::Compatible(cond.xla_output_shape, - expected_cond_output_shape), + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_without_side_effect) || + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_with_side_effect), errors::InvalidArgument( - "Output shape of loop condition should be (pred[]), got: ", + "Output shape of loop condition should be (pred[]) or " + "(pred[], token[]), got: ", xla::ShapeUtil::HumanString(cond.xla_output_shape))); int num_inputs = body.input_mapping.size(); @@ -283,11 +298,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); - // Sets non-variable outputs. + // Sets non-variable outputs and determines when resource variables start. 
+ int resource_index = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], xla::GetTupleElement(while_result, i)); + ++resource_index; + } else { + break; } } if (has_token_input_output_) { @@ -296,7 +315,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, ctx->num_outputs()); auto shape_or = builder->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); - OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(), errors::FailedPrecondition( "Token output is not token type: ", xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); @@ -309,7 +328,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); if (update.modified) { - int pos = body.outputs.size() + i; + int pos = resource_index + i; OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, @@ -329,8 +348,11 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building while loop"; } -REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); -REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); -REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("While").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); +REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 4612f19971a3ce6994aef303f751748b77ccda9a..b20adc592a0d3d2129c897218ddbfc891b4cd40a 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -78,7 +78,7 @@ class XlaConvOp : public XlaOpKernel { xla::XlaOp output = xla::ConvGeneralDilated( context->Input(0), context->Input(1), window_strides, padding, lhs_dilation, rhs_dilation, dnums_, feature_group_count, - &precision_config_); + /*batch_group_count=*/1, &precision_config_); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a30b4861f6b3a964c0c874a3affab7d6198264d7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/quantize.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaDequantizeOp : public XlaOpKernel { + public: + explicit XlaDequantizeOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("min_range", &min_range_)); + OP_REQUIRES_OK(context, context->GetAttr("max_range", &max_range_)); + OP_REQUIRES_OK(context, context->GetAttr("mode", &mode_)); + OP_REQUIRES_OK(context, + context->GetAttr("transpose_output", &transpose_output_)); + } + + void Compile(XlaOpKernelContext* context) override { + const xla::XlaOp& input = context->Input(0); + + xla::QuantizedRange range(min_range_, max_range_); + + xla::XlaOp output = + xla::Dequantize(input, range, mode_, transpose_output_); + context->SetOutput(0, output); + } + + private: + float min_range_; + float max_range_; + bool transpose_output_; + string mode_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaDequantizeOp); +}; + +REGISTER_XLA_OP(Name("XlaDequantize"), XlaDequantizeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..233ac8e7b455403f8ee65b95b1403ecefdb92c6b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/core/lib/core/bits.h" + +namespace tensorflow { +namespace { + +class XlaSelfAdjointEigOp : public XlaOpKernel { + public: + explicit XlaSelfAdjointEigOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = + xla::SelfAdjointEig(ctx->Input(0), lower_, max_iter_, epsilon_); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } + + private: + bool lower_; + int32 max_iter_; + float epsilon_; +}; + +class SelfAdjointEigV2Op : public XlaOpKernel { + public: + explicit SelfAdjointEigV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + int n = input_shape.dim_size(input_shape.dims() - 1); + // This is based on heuristics that approx log(n) sweep updates are needed. + // Note: the heuristics provides no theoretical guarantee, max_iter=100 and + // epsilon should be used to determine exit condition. 
+ int max_iter = 2 * tensorflow::Log2Ceiling(n); + auto result = xla::SelfAdjointEig(ctx->Input(0), true, max_iter, 1e-6); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } +}; + +REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes), + XlaSelfAdjointEigOp); +REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes), + SelfAdjointEigV2Op); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 3e7a761120317ff85947559b7b2e52be9232afb7..3d7b0bc959f9dbf3c1b9749379e2ea0d285b302b 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -15,8 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -33,27 +31,6 @@ cc_library( ], ) -cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - deps = [ - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/compiler/xla/client/lib:triangular_solve", - "//tensorflow/core:lib", - ], -) - cc_library( name = "random", srcs = ["random.cc"], @@ -69,35 +46,12 @@ cc_library( ], ) -cc_library( - name = "qr", - srcs = ["qr.cc"], - hdrs = ["qr.h"], - deps = [ - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - 
"//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/core:lib", - ], -) - cc_library( name = "scatter", srcs = ["scatter.cc"], hdrs = ["scatter.h"], deps = [ ":util", - ":while_loop", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -128,19 +82,3 @@ cc_library( "@com_google_absl//absl/types:span", ], ) - -cc_library( - name = "while_loop", - srcs = ["while_loop.cc"], - hdrs = ["while_loop.h"], - deps = [ - ":util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 2b1c2ced925d9fee7392986015a6e716a94d356f..1cd5a79171dccd57fc1b7941cdf16417301ff7f8 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -49,7 +48,7 @@ xla::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) { + if (num_index_dims > buffer_shape.rank()) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", xla::ShapeUtil::HumanString(indices_shape), @@ -141,8 +140,8 @@ xla::StatusOr XlaScatter( ? indices_shape.dimensions_size() - 1 : indices_shape.dimensions_size()); - int64 updates_rank = xla::ShapeUtil::Rank(updates_shape); - int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape); + int64 updates_rank = updates_shape.rank(); + int64 buffer_rank = buffer_shape.rank(); int64 num_window_dims_in_updates = buffer_rank - num_index_dims; // If the rank of `updates` is 0 and does not match the expected rank of @@ -157,7 +156,7 @@ xla::StatusOr XlaScatter( if (updates_rank == 0 && expected_updates_rank != 0) { new_updates = xla::Broadcast(updates, expected_updates_dims); TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); - updates_rank = xla::ShapeUtil::Rank(updates_shape); + updates_rank = updates_shape.rank(); } if (updates_rank > 0) { diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index c0bd172d17c192435ba8ee196f9def0491c0bf5c..06eda41611861060a1f1c4d028b96405d288efdb 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -54,6 +54,9 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::C64: return xla::ConstantR0(builder, value); break; + case xla::C128: + return xla::ConstantR0(builder, value); + break; 
default: LOG(FATAL) << "unhandled element type " << type; } @@ -90,6 +93,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::C64: literal = xla::LiteralUtil::CreateR0(value); break; + case xla::C128: + literal = xla::LiteralUtil::CreateR0(value); + break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; case xla::S16: diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 67d08290033361f16dfff42b06af9b253e84963a..749a7c3054a65d6ec9f9dc13f6f4a713ac9d3d5a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -77,7 +77,7 @@ Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { - TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && + TF_RET_CHECK(literal.shape().IsArray() && xla::ShapeUtil::ElementsIn(literal.shape()) == host_tensor->NumElements()); xla::PrimitiveType primitive_type; diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 15f4c38da29507da9e092c1d5725b5f95a81d1b9..44bccfe6474d175beda392ca17dfbcb08c0b1b11 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -49,7 +49,7 @@ using Types = std::pair, std::pair, std::pair>; -TYPED_TEST_CASE(LiteralUtilTest, Types); +TYPED_TEST_SUITE(LiteralUtilTest, Types); TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { using int_type = typename TypeParam::first_type; diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 4dce0a2102cf9c782850ccc7af4f14b59bd51e53..7140b6a1227a53290c3747892a55886a7f48513b 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -4,7 +4,11 @@ package( licenses(["notice"]) # Apache 2.0 
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_gen_op_wrapper_py", +) cc_library( name = "xla_ops", @@ -24,3 +28,14 @@ tf_gen_op_wrapper_py( ":xla_ops", ], ) + +tf_custom_op_library( + name = "_xla_ops.so", + srcs = [ + "xla_ops.cc", + ], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index bd2c0a5ee88869ba60701c0a7ace05857452eed9..ccd58071d350e605e0e1f0c2b43643a400e32c2c 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -56,6 +56,41 @@ lhs_output: the broadcasted LHS tensor rhs_output: the broadcasted RHS tensor )doc"); +REGISTER_OP("XlaSelfAdjointEig") + .Input("a: T") + .Attr("lower: bool") + .Attr("max_iter: int") + .Attr("epsilon: float") + .Output("w: T") + .Output("v: T") + .SetShapeFn(shape_inference::UnknownShape) + .Attr("T: numbertype") + .Doc(R"doc( +Computes the eigen decomposition of a batch of self-adjoint matrices +(Note: Only real inputs are supported). + +Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in +tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for +i=0...N-1. + +a: the input tensor. + +lower: a boolean specifies whether the calculation is done with the lower + triangular part or the upper triangular part. + +max_iter: maximum number of sweep update, i.e., the whole lower triangular + part or upper triangular part based on parameter lower. Heuristically, it has + been argued that approximatly logN sweeps are needed in practice (Ref: Golub & + van Loan "Matrix Computation"). + +epsilon: the tolerance ratio. + +w: The eigenvalues in ascending order, each repeated according to its + multiplicity. +v: The column v[..., :, i] is the normalized eigenvector corresponding to the + eigenvalue w[..., i]. 
+)doc"); + REGISTER_OP("XlaConv") .Input("lhs: T") .Input("rhs: T") @@ -369,7 +404,11 @@ REGISTER_OP("XlaKeyValueSort") .Output("sorted_values: V") .Attr("K: realnumbertype") .Attr("V: type") - .SetShapeFn(shape_inference::UnchangedShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + c->set_output(1, c->input(1)); + return Status::OK(); + }) .Doc(R"doc( Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort @@ -409,5 +448,29 @@ body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +REGISTER_OP("XlaDequantize") + .Input("input: uint32") + .Output("output: bfloat16") + .Attr("min_range: float") + .Attr("max_range: float") + .Attr("mode: string") + .Attr("transpose_output: bool") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Takes the packed uint32 input and unpacks the input to uint8 to do +Dequantization on deivce. + +input: Input tensors whose types is uint32, shape is [d0, ..., dn]. +output: Output tensors whose types is bloat16. If transpose_output is true, + output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output + is false, output shape is [d0,..., dn * 4]. +min_range: The minimum scalar value possibly produced for the input. +max_range: The maximum scalar value possibly produced for the input. +mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}. +transpose_output: Boolean to determine if output is transposed. transpose_output + is faster when input is large and rank of input is higher than 1. 
+)doc"); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index fef97b98c376d9df8bbfd9cb6651216895e46bf4..9abdb04d7736e8ff5225688af4759a522d3e7fc7 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -15,6 +15,7 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") tf_py_clif_cc( name = "xla_op_registry", @@ -27,9 +28,13 @@ tf_py_clif_cc( ], ) -py_library( +tf_custom_op_py_library( name = "xla", srcs = ["xla.py"], + dso = ["//tensorflow/compiler/tf2xla/ops:_xla_ops.so"], + kernels = [ + "//tensorflow/compiler/tf2xla/ops:xla_ops", + ], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/compiler/xla:xla_data_proto_py", diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 147e562658bbfc445f99268812e2c3ae1ee61e30..de4710d03a3e69afb04aa68e37961698f0e3a300 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -291,6 +291,10 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) +def self_adjoint_eig(a, lower, max_iter, epsilon): + return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon) + + dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice @@ -386,3 +390,4 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while +dequantize = gen_xla_ops.xla_dequantize diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 72b240996fb4d9dcb5f5dfd919da618cbae08c16..c20d6a5fd1f3bd7dad30cb3359d13ed4609a2250 100644 --- 
a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -65,6 +65,7 @@ CreateResourceOpInfoMap() { add("ResourceApplyFtrlV2" , kReadWrite, kVariable); add("ResourceApplyGradientDescent" , kReadWrite, kVariable); add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyKerasMomentum" , kReadWrite, kVariable); add("ResourceApplyPowerSign" , kReadWrite, kVariable); add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); @@ -76,6 +77,7 @@ CreateResourceOpInfoMap() { add("ResourceScatterMin" , kReadWrite, kVariable); add("ResourceScatterMul" , kReadWrite, kVariable); add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdSub" , kReadWrite, kVariable); add("ResourceScatterNdUpdate" , kReadWrite, kVariable); add("ResourceScatterSub" , kReadWrite, kVariable); add("ResourceScatterUpdate" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index b589512dcdfa32050281120aba6a5ae89a980c2f..8997b2f5c68da480e9d4cb1f7ff8776690363392 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -18,21 +18,81 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +namespace { + +Status PopulateInfeedLayoutVector(const xla::Shape& shape, + std::vector* layouts) { + if (shape.IsTuple()) { + int64 tuple_elements = xla::ShapeUtil::TupleElementCount(shape); + for (int64 i = 0; i < tuple_elements; ++i) { + const xla::Shape& subshape = + xla::ShapeUtil::GetTupleElementShape(shape, i); + TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts)); + } + } else if (xla::LayoutUtil::HasLayout(shape)) { + for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) { + layouts->push_back(dim); + } + } else { + layouts->insert(layouts->end(), shape.rank(), -1); + } + return Status::OK(); +} + +// Populate the output layout unless the minor_to_major array contains all -1 +// value, in which case the layout is considered missing and the API returns +// false. 
+xla::StatusOr MakeLayout(absl::Span minor_to_major, + xla::Layout* layout) { + if (std::all_of(minor_to_major.begin(), minor_to_major.end(), + [](int64 dim) { return dim == -1; })) { + return false; + } + std::vector dim_present(minor_to_major.size(), false); + for (auto dim : minor_to_major) { + if (dim < 0 || dim >= minor_to_major.size()) { + return errors::InvalidArgument("Layout dimension out of range: dim=", dim, + " rank=", minor_to_major.size()); + } + if (dim_present[dim]) { + return errors::InvalidArgument("Repeated layout dimension: dim=", dim); + } + dim_present[dim] = true; + } + *layout = xla::LayoutUtil::MakeLayout(minor_to_major); + return true; +} + +Status AssignLayout( + absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* shape) { + xla::Layout layout; + TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout)); + if (!has_layout && layout_func) { + layout = layout_func(*shape); + } + *shape->mutable_layout() = layout; + return Status::OK(); +} + +} // namespace // Convert an XLA Shape into the equivalent TensorFlow shape. 
Status XLAShapeToTensorShape(const xla::Shape& shape, TensorShape* tensor_shape) { - if (xla::ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return errors::InvalidArgument("XLA shape ", xla::ShapeUtil::HumanString(shape), " cannot be converted to a TensorShape"); } *tensor_shape = TensorShape(); - for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) { + for (int i = 0; i < shape.rank(); ++i) { tensor_shape->AddDim(shape.dimensions(i)); } return Status::OK(); @@ -61,4 +121,64 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } +xla::StatusOr> GetShapeLayoutVector(const xla::Shape& shape) { + std::vector layouts; + TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &layouts)); + return layouts; +} + +Status GetShapeWithLayout( + const xla::Shape& input_shape, absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* output_shape) { + if (input_shape.IsTuple()) { + int64 tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape); + std::vector shapes; + shapes.reserve(tuple_elements); + size_t position = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + const xla::Shape& shape = + xla::ShapeUtil::GetTupleElementShape(input_shape, i); + if (shape.IsTuple()) { + return errors::InvalidArgument( + "Nested tuples not supported: ", + xla::ShapeUtil::HumanString(input_shape)); + } + int64 rank = shape.rank(); + if (position + rank > minor_to_major.size()) { + return errors::InvalidArgument( + "Not enough layout attribute elements: position=", position, + " rank=", rank, " elements=", minor_to_major.size()); + } + shapes.push_back(shape); + TF_RETURN_IF_ERROR(AssignLayout( + absl::Span(minor_to_major).subspan(position, rank), + layout_func, &shapes.back())); + position += rank; + + VLOG(4) << "Shape[" << i + << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back()); + } + if (position != minor_to_major.size()) { + return errors::InvalidArgument( + "Too 
many elements passed in the layout attribute: position=", + position, " size=", minor_to_major.size()); + } + *output_shape = xla::ShapeUtil::MakeTupleShape(shapes); + } else { + int64 rank = input_shape.rank(); + if (rank != minor_to_major.size()) { + return errors::InvalidArgument( + "Wrong number of layout attribute elements: rank=", rank, + " elements=", minor_to_major.size()); + } + *output_shape = input_shape; + TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape)); + + VLOG(4) << "Shape[] = " + << xla::ShapeUtil::HumanStringWithLayout(*output_shape); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 0b231ea8e7a2d8e303e91911e2e0a36fc83e78b4..e775c4462c3dc15cf4b8d9e8d8e7d9a61e024cd0 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -18,7 +18,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ +#include + #include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -41,6 +44,25 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const TensorShape& tensor_shape); +// Given an XLA shape with layouts, builds a layout vector in the form able to +// be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/.... +// The returned vector is a linearized sequence of the minor-to-major values of +// the layouts held within the input shape. +// In case the input shape is a tuple, the minor-to-major values will be in the +// order of the tuple elements within the tuple shape. 
+// If a shape (or a subshape of a tuple shape) has missing layout, a rank long +// sequence of -1 values will be emittted. +xla::StatusOr> GetShapeLayoutVector(const xla::Shape& shape); + +// Given the input shape and a linearized sequence of the minor-to-major values +// of the layouts, create the output shape by rewriting the input shape layouts. +// If a layout is missing (has -1 values) for a matching tuple subshape, the +// layout_func will be called, if not nullptr. +Status GetShapeWithLayout( + const xla::Shape& input_shape, absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* output_shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index b233e6b2c28e1968bb74901fc684e808ae45ab60..412f31adbb7df52b2d6933be054cc6d40947dc44 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -24,6 +24,51 @@ const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; +const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer"; + +Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { + if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { + return errors::InvalidArgument("Node ", node->DebugString(), + " does not have attribute ", + kXlaHasHostTransferAttrName); + } + + if (node->type_string() == "_XlaRecvAtHost" || + node->type_string() == "_XlaSendFromHost") { + node->ClearAttr("device_ordinal"); + node->AddAttr("device_ordinal", device_ordinal); + } else if (node->type_string() == "If") { + AttrValue device_ordinal_value; + device_ordinal_value.set_i(device_ordinal); + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, 
&branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + node->ClearAttr(attr_name); + node->AddAttr(attr_name, branch_func); + } + } else if (node->type_string() == "While") { + AttrValue device_ordinal_value; + device_ordinal_value.set_i(device_ordinal); + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + node->ClearAttr(attr_name); + node->AddAttr(attr_name, branch_func); + } + } else if (HasNodeAttr(node->def(), "device_ordinal")) { + // Function call node containing outside compilation. + node->ClearAttr("device_ordinal"); + node->AddAttr("device_ordinal", device_ordinal); + } else { + return errors::Internal("Unknown node type to set 'device_ordinal': ", + node->DebugString()); + } + return Status::OK(); +} + std::set CalculateTokenInputsForOutputToken(const Graph& g) { std::set results; Node* first_side_effecting_node_on_path = nullptr; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index f22ddb2f58e1fa5c10ca0fdb956d9136942388b7..75e1f253fb08ae61b0336a8783b7449c69197dd1 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -35,6 +35,13 @@ extern const char kXlaTokenInputNodesAttrName[]; // node has side-effect dependency on current graph's token input. extern const char kXlaTokenArgNodeName[]; +// This node have XlaRecvAtHost/XlaSendFromHost in its associated functions. +extern const char kXlaHasHostTransferAttrName[]; + +// Sets device ordinal attribute for nodes with attribute +// `kXlaHasHostTransferAttrName`. +Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal); + // Calculates side-effect dependencies for the graph's token output. // Returns a set of node names representing these dependencies. 
std::set CalculateTokenInputsForOutputToken(const Graph& g); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 9fac16a9700419b189bf5393c2b8bd7d76c6c1cc..28a4566c9d284fb8410a2d618f368c4dd2c1d893 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -243,7 +243,9 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TensorShape shape; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); + arg.shape = shape; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } @@ -252,7 +254,8 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. -Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, +Status ConvertGraphToXla(std::unique_ptr graph, + const tf2xla::Config& config, xla::Client* client, xla::XlaComputation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { @@ -262,6 +265,19 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); + // Populate arguments with resource variables from the config. The variables + // get turned into inputs and outputs. + for (const tf2xla::Variable& variable : config.variable()) { + XlaCompiler::Argument arg; + arg.type = variable.type(); + arg.kind = XlaCompiler::Argument::kResource; + arg.shape = variable.shape(); + arg.name = variable.node_name(); + arg.resource_kind = XlaResource::kVariable; + arg.initialized = true; + xla_args.push_back(std::move(arg)); + } + // Compile the graph into an XLA computation. 
XlaCompiler::Options compiler_options; compiler_options.client = client; @@ -359,7 +375,8 @@ Status ConvertGraphDefToXla(const GraphDef& graph_def, xla::XlaComputation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); - TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); + TF_RETURN_IF_ERROR( + ConvertGraphToXla(std::move(graph), config, client, computation)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.proto b/tensorflow/compiler/tf2xla/tf2xla.proto index 18c9089f5fa0e9792a4763d9bfac4c4e826eb5b2..5627af7452b99da594c1c214d0b556d8d70544d5 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.proto +++ b/tensorflow/compiler/tf2xla/tf2xla.proto @@ -39,6 +39,15 @@ message Fetch { string name = 2; // Optional name for generated code. }; +// Variable represents a resource variable with the given name, shape and type. +message Variable { + string node_name = 1; + string name = + 2; // Optional name for generated code. If empty, node_name will be used. + TensorShapeProto shape = 3; + DataType type = 4; +} + // Config represents configuration information for tf2xla conversion. message Config { // Each feed is a positional input argument for the generated computation. @@ -47,4 +56,6 @@ message Config { // Each fetch is a positional output argument for the generated computation. // The order of each entry matches the order of each output argument. repeated Fetch fetch = 2; + // Each variable is a named input and output of the generated computation. 
+ repeated Variable variable = 3; }; diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ab26d939ccba75ce58609ffd71c7ccadbe90cfa8..24afe595b18b823818bd8fe65bc599af8bce040a 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -91,7 +91,7 @@ TEST(ConvertGraphDefToXla, Sum) { client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); xla::Literal result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); + EXPECT_EQ("(\ns32[] 42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index cc81772e8c5da710bc733f7e4f5fe820b2c2d110..88c03a6056ac6484013c3fd32c9889899b5c15c5 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -122,7 +122,12 @@ Status ReplaceArgUsageWithConstNode( for (const auto& iter : const_input_index_to_node) { int arg_index = iter.first; - Node* const_node = g->CopyNode(iter.second); + NodeDef const_def = iter.second->def(); + const_def.set_name(g->NewName(const_def.name())); + Status s; + Node* const_node = g->AddNode(const_def, &s); + TF_RETURN_IF_ERROR(s); + Node* arg_node = arg_nodes[arg_index]; // Collect all usages of the _Arg node. @@ -265,6 +270,13 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node, } // Check if i-th retval's input comes from i-th arg directly. + // For resource variable input of While nodes, TF2XLA convention is to place + // them at the end of all inputs (after all data inputs), and *not* return + // them. So number of While node inputs might be larger than number of its + // outputs. 
+ if (i >= body_func->signature().output_arg_size()) { + continue; + } const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i); auto output_arg_input = body_func->ret().find(output_arg.name()); if (output_arg_input == body_func->ret().end()) { @@ -364,6 +376,7 @@ Status AddPlaceholdersForFeeds( GraphDef gd; *gd.mutable_versions() = graph_def->versions(); *gd.add_node() = *existing; + MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0)); TF_RETURN_IF_ERROR( AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/)); @@ -390,6 +403,7 @@ Status AddPlaceholdersForFeeds( // in this code. for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { const PlaceholderInfo& info = it->second; + // TODO(shikharagarwal): Add original node information. NodeDef* d = graph_def->add_node(); d->set_name(info.placeholder_name); d->set_op("PlaceholderV2"); @@ -557,6 +571,12 @@ bool HasAssociatedFunction(const NodeDef& node_def, return true; } + if (node_def.op() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. + return false; + } + for (const auto& iter : node_def.attr()) { if (iter.second.has_func()) { return true; @@ -578,6 +598,9 @@ std::vector GetAssociatedFunctions( // This is a SymbolicGradient op. AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs)); + } else if (node.type_string() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. } else { // Collect all function attrs for the node. for (auto& iter : node.attrs()) { @@ -599,7 +622,9 @@ Status RewriteAssociatedFunction( switch (associated_function.type()) { case AssociatedFunctionInfo::kFunctionCallNode: { // Change this node to call the new function. 
- NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + NodeDebugInfo debug_info(*node); + NodeDefBuilder builder(node->name(), rewritten_function_name, fld, + &debug_info); for (auto attr : node->attrs()) { builder.Attr(attr.first, attr.second); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 202e929315cacd4d6cdfc69d50639d8a427ec6c2..28b4744470e7d28863b5f7275f829b9bd59641e1 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -21,11 +21,13 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -329,5 +331,90 @@ TEST(CachedFunctionHandles, Basic) { TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles()); } +TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) { + FunctionLibraryDefinition fld(OpRegistry::Global(), {}); + { + // Cond graph & body graph. 
+ Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0); + auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1); + auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef)); + FunctionDef body_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(body_fdef)); + } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({})); + auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({})); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("cond"); + body_fn.set_name("body"); + auto while_op = + ops::While(scope.WithOpName("while"), + std::initializer_list{pred, input}, cond_fn, body_fn); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld)); +} + +TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) { + FunctionLibraryDefinition fld(OpRegistry::Global(), {}); + { + // Cond graph & body graph. 
+ Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto input = ops::_Arg(scope.WithOpName("arg1"), DT_BOOL, 1); + auto duplicate_name = ops::NoOp(scope.WithOpName("duplicate_name")); + auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef)); + FunctionDef body_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(body_fdef)); + } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = + ops::Const(scope.WithOpName("duplicate_name"), false, TensorShape({})); + auto input = ops::Const(scope.WithOpName("input"), false, TensorShape({})); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("cond"); + body_fn.set_name("body"); + auto while_op = + ops::While(scope.WithOpName("while"), + std::initializer_list{pred, input}, cond_fn, body_fn); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld)); + + // Check that in rewritten body function, the NoOp node still has name + // "duplicate_name", and the copied Const node has name "duplicate_name/_0". 
+ auto node_name_index = graph.BuildNodeNameIndex(); + Node* while_node = node_name_index["while"]; + ASSERT_NE(while_node, nullptr); + TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn)); + const FunctionDef* rewritten_body_fn = fld.Find(body_fn.name()); + ASSERT_NE(rewritten_body_fn, nullptr); + std::unordered_map nodes; + for (const NodeDef& node_def : rewritten_body_fn->node_def()) { + nodes[node_def.name()] = node_def; + } + auto noop_def = nodes.find("duplicate_name"); + ASSERT_NE(noop_def, nodes.end()); + EXPECT_EQ(noop_def->second.op(), "NoOp"); + auto const_def = nodes.find("duplicate_name/_0"); + ASSERT_NE(const_def, nodes.end()); + EXPECT_EQ(const_def->second.op(), "Const"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index d00b1376620c0c9d112c7d7426758f6d3f25e86f..732f957d7329c93ad104dacf5190948fbfd7974b 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -69,6 +69,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); + case tensorflow::DT_COMPLEX128: + *type = xla::C128; + return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index ddb284966eeb97cc7c9d3ed77fb313e567975e59..5bd0277c051711f2677b90a2679662899521e94a 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -60,8 +60,6 @@ class XlaCompilationAllocator : public Allocator { // buffers, so they get ids to track. 
bool ShouldAllocateEmptyTensors() override { return true; } - void GetStats(AllocatorStats* stats) override { stats->Clear(); } - private: // Don't run any constructors or destructors for complex objects, // since there is no backing store for the tensor to run them diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index c7341cf8b9e8d7a06fd304ae8766420d20f0c16e..de2e485a47c18ae8e58a06aba408dbb61a30d00a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -59,45 +59,8 @@ class XlaCompiledCpuFunction { // AOT this is backed by data compiled into the object file. // // The contents of StaticData are XLA-internal implementation details and - // should not be relied on by clients. - // - // TODO(sanjoy): Come up with a cleaner way to express the contraint we want - // here: generated XlaCompiledCpuFunction subclasses should be able to create - // instances of StaticData but only XlaCompiledCpuFunction should be able to - // read from StaticData instances. + // should not be relied on by clients (and therefore are private). 
class StaticData { - public: - void set_raw_function(RawFunction raw_function) { - raw_function_ = raw_function; - } - void set_buffer_infos( - const cpu_function_runtime::BufferInfo* buffer_infos) { - buffer_infos_ = buffer_infos; - } - void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; } - void set_arg_index_table(const int32* arg_index_table) { - arg_index_table_ = arg_index_table; - } - void set_num_args(int64 num_args) { num_args_ = num_args; } - void set_result_index(size_t result_index) { result_index_ = result_index; } - void set_arg_names(const char** arg_names) { arg_names_ = arg_names; } - void set_result_names(const char** result_names) { - result_names_ = result_names; - } - void set_program_shape(const xla::ProgramShapeProto* program_shape) { - program_shape_ = program_shape; - } - const xla::HloProfilePrinterData* hlo_profile_printer_data() const { - return hlo_profile_printer_data_; - } - void set_hlo_profile_printer_data( - const xla::HloProfilePrinterData* hlo_profile_printer_data) { - hlo_profile_printer_data_ = hlo_profile_printer_data; - } - void set_profile_counters_size(int64 profile_counters_size) { - profile_counters_size_ = profile_counters_size; - } - private: // The raw function to call. RawFunction raw_function_; @@ -134,7 +97,8 @@ class XlaCompiledCpuFunction { // declared so we don't have access to that information here. int64 profile_counters_size_ = 0; - // Only XlaCompiledCpuFunction is allowed to read the above fields. + // Only XlaCompiledCpuFunction is allowed to read and write the above + // fields. 
friend class XlaCompiledCpuFunction; }; @@ -148,7 +112,7 @@ class XlaCompiledCpuFunction { RESULTS_PROFILES_AND_TEMPS_ONLY, }; - XlaCompiledCpuFunction( + explicit XlaCompiledCpuFunction( const StaticData& static_data, AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); @@ -280,6 +244,76 @@ class XlaCompiledCpuFunction { return *hlo_profile_printer_data_; } + protected: + // --------------------------------------------------------------------------- + // Accessors for reading from and writing to instances of `StaticData`. + // + // Classes generated by tfcompile can call these because the generated classes + // inherit from `XlaCompiledCpuFunction`. `XlaJitCompiledCpuFunction` can + // call these because it is explicitly added as a friend. + + static void set_static_data_raw_function(StaticData* static_data, + RawFunction raw_function) { + static_data->raw_function_ = raw_function; + } + + static void set_static_data_buffer_infos( + StaticData* static_data, + const cpu_function_runtime::BufferInfo* buffer_infos) { + static_data->buffer_infos_ = buffer_infos; + } + + static void set_static_data_num_buffers(StaticData* static_data, + size_t num_buffers) { + static_data->num_buffers_ = num_buffers; + } + + static void set_static_data_arg_index_table(StaticData* static_data, + const int32* arg_index_table) { + static_data->arg_index_table_ = arg_index_table; + } + + static void set_static_data_num_args(StaticData* static_data, + int64 num_args) { + static_data->num_args_ = num_args; + } + + static void set_static_data_result_index(StaticData* static_data, + size_t result_index) { + static_data->result_index_ = result_index; + } + + static void set_static_data_arg_names(StaticData* static_data, + const char** arg_names) { + static_data->arg_names_ = arg_names; + } + + static void set_static_data_result_names(StaticData* static_data, + const char** result_names) { + static_data->result_names_ = result_names; + } + + 
static void set_static_data_program_shape( + StaticData* static_data, const xla::ProgramShapeProto* program_shape) { + static_data->program_shape_ = program_shape; + } + + static void set_static_data_hlo_profile_printer_data( + StaticData* static_data, + const xla::HloProfilePrinterData* hlo_profile_printer_data) { + static_data->hlo_profile_printer_data_ = hlo_profile_printer_data; + } + + static const xla::HloProfilePrinterData* + get_static_data_hlo_profile_printer_data(StaticData* static_data) { + return static_data->hlo_profile_printer_data_; + } + + static void set_static_data_profile_counters_size( + StaticData* static_data, int64 profile_counters_size) { + static_data->profile_counters_size_ = profile_counters_size; + } + private: const RawFunction raw_function_; const size_t result_index_; @@ -313,6 +347,10 @@ class XlaCompiledCpuFunction { const char** result_names_ = nullptr; const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; + + // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the + // `set_static_data_*` static methods above. + friend class XlaJitCompiledCpuFunction; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index ee461a3c07d4db514c7697e005a9371be4b54dd0..3221ec5b727de1f792cd61b792ee917588d56cf9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -42,6 +43,8 @@ limitations under the License. 
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -57,7 +60,11 @@ Status CheckSignature(const DataTypeVector& types, " elements while function has ", types.size()); } for (int i = 0; i < types.size(); ++i) { - if (types[i] != args[i].type && types[i] != DT_RESOURCE) { + // Don't perform type checks on resource variables and tensor + // lists (DT_VARIANT) as we have to trick the type system in order to + // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor. + if (types[i] != args[i].type && types[i] != DT_RESOURCE && + types[i] != DT_VARIANT) { return errors::Internal( "Argument ", i, " has declared type ", DataTypeString(args[i].type), " but function parameter has type ", DataTypeString(types[i])); @@ -178,9 +185,10 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); - // Keeps track of which retvals have layout to update. The first element is - // the output index, second element is the new layout. - std::vector> retval_to_update_layout; + // Keeps track of the layout of each retval. If a retval is not in this list, + // a descending layout is used. The first element is the output index, second + // element is the new layout. 
+ std::vector> retval_index_and_layout; for (int i = 0; i < retvals.size(); ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; @@ -192,6 +200,8 @@ Status BuildComputation( output.shape = output.constant_value.shape(); break; + case XlaExpression::Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case XlaExpression::Kind::kXlaOp: { output.is_constant = false; TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); @@ -207,7 +217,7 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( output.shape, output.type)); value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); - retval_to_update_layout.emplace_back(elems.size(), shape.layout()); + retval_index_and_layout.emplace_back(elems.size(), shape.layout()); } else if (it != retval_cores.end()) { // Apply the sharding to the output, if there is a core assignment. value = identity_op(value); @@ -280,6 +290,11 @@ Status BuildComputation( // Ensures the correct sharding is applied to the output. handle = identity_op(handle); + // Set layout of the retval to device representation layout. + if (resource->representation_shape().has_value()) { + retval_index_and_layout.emplace_back( + elems.size(), resource->representation_shape()->layout()); + } elems.push_back(handle); } } @@ -309,15 +324,15 @@ Status BuildComputation( computation->GetProgramShape()); *output_shape = program_shape.result(); // Update the output layout to the layout of retval. 
- for (auto& update : retval_to_update_layout) { + for (auto& index_and_layout : retval_index_and_layout) { if (!always_return_tuple && elems.size() == 1) { - *output_shape->mutable_layout() = update.second; + *output_shape->mutable_layout() = index_and_layout.second; continue; } - xla::Shape* output_sub_shape = - xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); - *output_sub_shape->mutable_layout() = update.second; + xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape( + output_shape, {index_and_layout.first}); + *output_sub_shape->mutable_layout() = index_and_layout.second; } return Status::OK(); } @@ -333,8 +348,21 @@ bool XlaCompiler::Argument::operator==( other.tensor_array_gradients)) { return false; } - if (shape != other.shape) { - return false; + if (absl::holds_alternative(shape)) { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (!xla::Shape::Equal()(absl::get(shape), + absl::get(other.shape))) { + return false; + } + } else { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (absl::get(shape) != absl::get(other.shape)) { + return false; + } } if (constant_value.shape() != other.constant_value.shape()) { return false; @@ -348,7 +376,7 @@ string XlaCompiler::Argument::HumanString() const { common = absl::StrCat(" name=", name); } absl::StrAppend(&common, " type=", DataTypeString(type), - " shape=", shape.DebugString()); + " shape=", ShapeHumanString()); switch (kind) { case kInvalid: return "invalid"; @@ -375,6 +403,23 @@ string XlaCompiler::Argument::HumanString() const { } } +std::vector XlaCompiler::Argument::DimensionSizes() const { + if (absl::holds_alternative(shape)) { + return xla::InlinedVectorToVector( + absl::get(shape).dim_sizes()); + } else { + return absl::get(shape).dimensions(); + } +} + +string XlaCompiler::Argument::ShapeHumanString() const { + if (absl::holds_alternative(shape)) { + return absl::get(shape).DebugString(); + } else { + return 
absl::get(shape).DebugString(); + } +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), @@ -462,8 +507,34 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { opts.set_do_function_inlining(true); opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); + // Do not constant fold nodes that output DT_VARIANT type tensors. + // XLA does not support Const nodes of Variant type since it needs + // to know the original ops to be able to compile them to the relevant + // XLA form. + // TODO(srbs): This filter is a little conservative. E.g. a subgraph of + // the form: + // Const + // | + // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op + // | + // (Discard popped list) + // + // Would have been reduced to "Const -> Op" without this filter. + // However since we are only allowed to specify the filter at the "Node" + // level there is no good way to allow the above behavior. So we + // disallow any sort of constant folding on Variant nodes for now. 
+ auto cf_consider_fn = [](const Node* n) { + for (const auto& output_arg : n->op_def().output_arg()) { + if (output_arg.type() == DT_VARIANT) { + return false; + } + } + return true; + }; + GraphOptimizer::Options graph_optimizer_options; + graph_optimizer_options.cf_consider_fn = cf_consider_fn; optimizer.Optimize(flib_runtime_, flib_runtime_->env(), - /*device=*/nullptr, &graph, /*shape_map=*/nullptr); + /*device=*/nullptr, &graph, graph_optimizer_options); return graph; } @@ -548,11 +619,22 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { if (is_entry_computation) { - TF_ASSIGN_OR_RETURN( - *xla_shape, options_.shape_representation_fn(arg.shape, arg.type)); + TensorShape shape; + if (absl::holds_alternative(arg.shape)) { + shape = absl::get(arg.shape); + } else { + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(absl::get(arg.shape), &shape)); + } + TF_ASSIGN_OR_RETURN(*xla_shape, + options_.shape_representation_fn(shape, arg.type)); } else { - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, arg.shape, xla_shape)); + if (absl::holds_alternative(arg.shape)) { + *xla_shape = absl::get(arg.shape); + } else { + TF_RETURN_IF_ERROR(TensorShapeToXLAShape( + arg.type, absl::get(arg.shape), xla_shape)); + } } return Status::OK(); } @@ -561,8 +643,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( - arg.shape, arg.type)); + TF_RET_CHECK(absl::holds_alternative(arg.shape)); + TF_ASSIGN_OR_RETURN(*xla_shape, + options_.shape_representation_fn( + absl::get(arg.shape), arg.type)); return Status::OK(); } @@ -571,9 +655,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return errors::InvalidArgument( "Negative max_array_size in XLAShapeForArgument"); } + 
TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; shape.AddDim(arg.max_array_size); - shape.AppendShape(arg.shape); + shape.AppendShape(absl::get(arg.shape)); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); if (!arg.tensor_array_gradients.empty()) { @@ -588,9 +673,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return errors::InvalidArgument( "Negative max_array_size in XLAShapeForArgument"); } + TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; shape.AddDim(arg.max_array_size); - shape.AppendShape(arg.shape); + shape.AppendShape(absl::get(arg.shape)); xla::Shape buffer_shape; TF_RETURN_IF_ERROR( TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); @@ -620,14 +706,15 @@ Status XlaCompiler::BuildArguments( bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, const std::map& arg_cores, std::vector* arg_expressions, - std::vector* input_mapping, std::vector* input_shapes, + std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation) { arg_expressions->resize(args.size()); // Argument numbers of arguments and resources that are to be passed to the - // XLA computation as runtime parameters. - input_mapping->clear(); - input_mapping->reserve(args.size()); + // XLA computation as runtime parameters. `input_to_args[a] = b` means that + // the a'th XLA input corresponds to the b'th original arg indexes. + input_to_args->clear(); + input_to_args->reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. for (std::vector::size_type i = 0; i < args.size(); @@ -637,24 +724,25 @@ Status XlaCompiler::BuildArguments( switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid); + TF_RET_CHECK(absl::holds_alternative(arg.shape)); // TODO(phawkins): this code assumes that resource arguments do not // alias. 
XlaResource* resource = context->AddResource(absl::make_unique( - arg.resource_kind, i, arg.name, arg.type, arg.shape, - xla::XlaOp(), + arg.resource_kind, i, arg.name, arg.type, + absl::get(arg.shape), xla::XlaOp(), /*max_array_size=*/arg.max_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, /*tensor_array_multiple_writes_aggregate=*/true)); arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { - input_mapping->push_back(i); + input_to_args->push_back(i); } break; } case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kToken: { - input_mapping->push_back(i); + input_to_args->push_back(i); break; } case XlaCompiler::Argument::kConstant: @@ -666,15 +754,23 @@ Status XlaCompiler::BuildArguments( } } - if (input_mapping->empty()) { + if (input_to_args->empty()) { return Status::OK(); } - std::vector arg_shapes(input_mapping->size()); - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds + // to the d'th XLA input. Note that the value -1 corresponds to constants, or + // other args that don't correspond to an input. + std::vector arg_to_inputs(args.size(), -1); + for (int i = 0; i < input_to_args->size(); i++) { + arg_to_inputs[input_to_args->at(i)] = i; + } + + std::vector arg_shapes(input_to_args->size()); + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { // Computes the shapes of non-constant arguments. TF_RETURN_IF_ERROR(XLAShapeForArgument( - args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); + args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -691,13 +787,13 @@ Status XlaCompiler::BuildArguments( builder->SetOpMetadata(arg_metadata); // Build parameter handles for non-constant arguments. 
- std::vector arg_handles(input_mapping->size()); + std::vector arg_handles(input_to_args->size()); if (use_tuple_arg) { xla::XlaOp tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : *input_mapping) { + for (int64 parameter : *input_to_args) { auto it = arg_cores.find(parameter); const int core = it == arg_cores.end() ? 0 : it->second; *tuple_sharding.add_tuple_shardings() = @@ -709,7 +805,19 @@ Status XlaCompiler::BuildArguments( } else { tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + + for (int i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( + /*dynamic_size_param_num=*/0, {dynamic_size_param_index}, + /*target_param_num=*/0, /*target_param_index=*/{i}, + dim_and_arg_num.first)); + } + } + + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_cores.find(i); const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( @@ -718,7 +826,7 @@ Status XlaCompiler::BuildArguments( arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_cores.find(i); const int core = it == arg_cores.end() ? 
-1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( @@ -727,6 +835,17 @@ Status XlaCompiler::BuildArguments( arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], absl::StrCat("arg", i)); } + + for (int i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( + /*dynamic_size_param_num=*/dynamic_size_param_index, {}, + /*target_param_num=*/i, /*target_param_index=*/{}, + dim_and_arg_num.first)); + } + } } builder->ClearOpMetadata(); @@ -734,12 +853,12 @@ Status XlaCompiler::BuildArguments( // Fill in the handles in non-constant arguments, and reshape parameters // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) - << " name: " << arg.name << " TF arg " << input_mapping->at(i); - XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)]; + << " name: " << arg.name << " TF arg " << input_to_args->at(i); + XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); @@ -756,7 +875,7 @@ Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. 
if (is_entry_computation) { arg_expression = XlaExpression::XlaOp( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); + xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type); } else { arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } @@ -997,8 +1116,17 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->outputs.resize(context->retvals().size()); std::vector retvals = context->retvals(); if (options.resolve_compile_time_constants) { - TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants( - client(), absl::Span(retvals))); + Status status = ResolveConstantExpressionsToConstants( + client(), absl::Span(retvals)); + + // If the HloEvaluator has not implemented an expression, just evaluate it + // at runtime. + if (status.code() == error::UNIMPLEMENTED) { + ConvertConstantsToExpressions(&builder, + absl::Span(retvals)); + } else { + TF_RETURN_IF_ERROR(status); + } } else { ConvertConstantsToExpressions(&builder, absl::Span(retvals)); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 0d801b73a8c2651305328384377751254ecaa41d..ad3144b41bdf3fc8b75ab5230e8e128df2962884 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" @@ -124,7 +125,8 @@ class XlaCompiler { DataType type = DT_INVALID; // The shape of the argument. For: - // * a parameter: the shape of the parameter. + // * a parameter: the shape of the parameter. We allow setting the xla shape + // if known. This helps avoid conversions to and from TensorShape. // * a constant: ignored; the shape given by constant_value is used // instead. 
// * an uninitialized resource: ignored. We don't yet know the shape of an @@ -133,7 +135,7 @@ class XlaCompiler { // * an initialized TensorArray or Stack resource: the shape of an entry in // the TensorArray/Stack. Note this is the size of a single entry, not the // XLA data structure that represents the complete stack/array. - TensorShape shape; + absl::variant shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -157,10 +159,20 @@ class XlaCompiler { // as `tensor_array_gradients`. std::set tensor_array_gradients; + // dynamic dims to arg number map. Empty if no dynamic shapes. + std::map dynamic_dim_to_arg_num_map; + bool is_pad_arg = false; + bool operator==(const Argument& other) const; // Returns a human-readable summary of the argument. string HumanString() const; + + // Returns the dimension sizes for either TensorShape or xla::Shape. + std::vector DimensionSizes() const; + + // Returns the human-readable string for either TensorShape or xla::Shape. + string ShapeHumanString() const; }; // Options pertaining to an individual call to CompileGraph() or @@ -420,7 +432,7 @@ class XlaCompiler { XlaContext* context, const std::map& arg_cores, std::vector* arg_expressions, - std::vector* input_mapping, + std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index fe2a5f5b0c9ea6b5f2bb71df836fdcabf9a0cf23..b31137867d738944eaaa73e142ad8538ec6b854a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -82,7 +82,7 @@ namespace { // compiled kernels. 
class DummyResourceForTest : public ResourceBase { public: - string DebugString() override { return "dummy"; } + string DebugString() const override { return "dummy"; } void Increment() { ++value_; } int Get() { return value_; } @@ -277,6 +277,97 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } +// Tests that the compiler can correctly propagate the layout assigned by +// shape_representation_fn_ to return types. +TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. 
+ std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + options.shape_representation_fn = + [](const TensorShape& shape, DataType dt) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + // Check that the return shapes are correctly tranposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + +// The layout of resource variable shouldn't change after transpose +TEST_F(XlaCompilerTest, TransposeVariables) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. 
+ auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto transposed_read = ops::Transpose(scope, read, {1, 0}); + auto reshape = ops::Reshape(scope, transposed_read, {2, 3}); + auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose", + std::move(graph), args, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0}); + // Check that the return shapes are correctly tranposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + // Tests that the compiler doesn't reorder the parameters. 
TEST_F(XlaCompilerTest, MixedOrderArguments) { for (bool swap_order : {false, true}) { @@ -1362,7 +1453,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), args, &result)); EXPECT_EQ(result.xla_input_shapes.size(), 1); - EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_TRUE(result.xla_output_shape.IsTuple()); EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); } { @@ -1380,11 +1471,11 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), args, &result)); EXPECT_EQ(result.xla_input_shapes.size(), 2); - EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[1])); - EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_TRUE(result.xla_input_shapes[1].IsToken()); + EXPECT_TRUE(result.xla_output_shape.IsTuple()); EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 2); - EXPECT_TRUE(xla::ShapeUtil::IsToken( - xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1))); + EXPECT_TRUE(xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1) + .IsToken()); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index a69af70503376b6c0905deb8980abdc3254a6e47..3f787fd86c9f7366a7728dcf146a3797ba672bc3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -61,7 +61,7 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) : compiler_(compiler), builder_(builder) {} -string XlaContext::DebugString() { return "XLA JIT context"; } +string XlaContext::DebugString() const { return "XLA JIT context"; } void XlaContext::SetRetval(int index, const XlaExpression& expression) { if (retvals_.size() <= index) { @@ -76,7 +76,7 @@ XlaResource* XlaContext::AddResource(std::unique_ptr 
resource) { } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { - return LookupOrCreate(type, &max_func_, [this, type] { + return LookupOrCreate(type, &max_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Max() for " << type_string; xla::XlaBuilder b("max<" + type_string + ">"); @@ -92,7 +92,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { - return LookupOrCreate(type, &min_func_, [this, type] { + return LookupOrCreate(type, &min_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Min() for " << type_string; xla::XlaBuilder b("min<" + type_string + ">"); @@ -108,7 +108,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { - return LookupOrCreate(type, &add_func_, [this, type] { + return LookupOrCreate(type, &add_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Add() for " << type_string; xla::XlaBuilder b("add<" + type_string + ">"); @@ -124,7 +124,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { - return LookupOrCreate(type, &mul_func_, [this, type] { + return LookupOrCreate(type, &mul_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Mul() for " << type_string; xla::XlaBuilder b("mul<" + type_string + ">"); diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 0767d1faac14cedb8666f6cc37175eb7b55f6158..eb4ad3fe6a14b42a4df2c73c71cb6df1331fd796 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -47,7 +47,7 @@ class XlaContext : public ResourceBase { XlaContext(XlaCompiler* 
compiler, xla::XlaBuilder* builder); // Virtual method defined by ResourceBase. - string DebugString() override; + string DebugString() const override; XlaCompiler* compiler() const { return compiler_; } diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index ca0309166b7c73d1a5a818091e2a30fa112a4de4..3d228c92adcbe3d093a4fe70d157e57ab3e80c80 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -46,6 +46,14 @@ XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { return e; } +XlaExpression XlaExpression::TensorList(xla::XlaOp tensor_list) { + XlaExpression e; + e.kind_ = Kind::kTensorList; + e.dtype_ = DT_VARIANT; + e.handle_ = tensor_list; + return e; +} + XlaExpression XlaExpression::Resource(XlaResource* resource) { XlaExpression e; e.kind_ = Kind::kResource; @@ -64,6 +72,8 @@ string XlaExpression::HumanString() const { return "xla_op"; case Kind::kResource: return "resource"; + case Kind::kTensorList: + return "tensor_list"; } } @@ -76,6 +86,8 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { HostTensorToBorrowingLiteral(constant_value_, &literal)); return xla::ConstantLiteral(builder, literal); } + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case Kind::kXlaOp: if (builder != handle_.builder()) { return errors::InvalidArgument( @@ -96,7 +108,10 @@ xla::StatusOr> XlaExpression::ResolveConstant( return {constant_value()}; case Kind::kXlaOp: break; + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case Kind::kResource: + TF_FALLTHROUGH_INTENDED; case Kind::kInvalid: return errors::InvalidArgument( "ResolveConstant called on XlaExpression: ", HumanString()); @@ -134,6 +149,8 @@ xla::StatusOr XlaExpression::GetShape() const { TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); return shape; } + case Kind::kTensorList: + return TensorShape({}); case Kind::kResource: return TensorShape({}); case 
Kind::kInvalid: diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index bed6761d362a98d344003c1edea342e68c31ef07..ac0232d8924cf2c9e35ad3f0772a3a2adc18af87 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -32,11 +32,16 @@ namespace tensorflow { // * a constant tensor. // * an xla::XlaOp, representing a symbolic XLA value. // * a resource, e.g., a variable, represented as an XlaResource pointer. +// * a tensor list, represented by a tuple of tensors and the list length. // // Constant tensors are mostly an optimization to avoid passing large constants // to XLA, but are also sometimes used to represent tensors that have no XLA // representation, for example, DT_STRING tensors. A canonical use case might be // an error message string. +// +// Tensor lists are very similar to xla::XlaOp, however they require some +// specific logic around shape management since the tuples are not supported by +// TensorFlow. class XlaExpression { public: enum class Kind { @@ -44,6 +49,7 @@ class XlaExpression { kConstant, kXlaOp, kResource, + kTensorList, }; XlaExpression(); @@ -62,6 +68,9 @@ class XlaExpression { // be derived from the XLA type. static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + // Builds a tensor list expression. + static XlaExpression TensorList(xla::XlaOp tensor_list); + // Builds a resource expression. static XlaExpression Resource(XlaResource* resource); @@ -100,7 +109,8 @@ class XlaExpression { DataType dtype_ = DT_INVALID; - // The XLA handle of the expression's computation, if kind_ == kXlaOp. + // The XLA handle of the expression's computation, if kind_ == kXlaOp or + // a tuple expression if kind_ == kTensorList. xla::XlaOp handle_; // The value of the constant, if kind_ == kConstant. 
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index c2c0751211180c3715a19d6c78e34659fd18914e..7bb1ad27467a5b281626de4203169e575288f9ee 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -34,63 +34,6 @@ limitations under the License. namespace tensorflow { -namespace { - -xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, - bool is_min) { - xla::XlaBuilder* builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); - xla::XlaOp init_value; - xla::XlaComputation reducer; - if (is_min) { - init_value = xla::MaxValue(builder, input_shape.element_type()); - reducer = - xla::CreateScalarMinComputation(input_shape.element_type(), builder); - } else { - init_value = xla::MinValue(builder, input_shape.element_type()); - reducer = - xla::CreateScalarMaxComputation(input_shape.element_type(), builder); - } - - xla::XlaOp input_max = xla::Reduce(input, init_value, reducer, - /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. 
- xla::XlaOp partial_mask = xla::ConvertElementType( - xla::Eq(input, input_max, broadcast_dims), output_type); - - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; - xla::XlaOp shift_amount = - xla::ConstantR0WithType(builder, output_type, bits_in_type); - xla::XlaOp full_mask = xla::ShiftRightArithmetic( - xla::ShiftLeft(partial_mask, shift_amount), shift_amount); - - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - - const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis); - xla::XlaOp iota = xla::Iota(builder, output_type, axis_size); - xla::XlaOp product = - xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. 
- return xla::Reduce(product, xla::MinValue(builder, output_type), - xla::CreateScalarMaxComputation(output_type, builder), - /*dimensions_to_reduce=*/{axis}); - }); -} - -} // namespace - xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); @@ -120,7 +63,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, /* static */ Status XlaHelpers::ReshapeLiteral( const xla::Literal& input, absl::Span dimensions, xla::Literal* output) { - if (xla::ShapeUtil::IsTuple(input.shape())) { + if (input.shape().IsTuple()) { return errors::InvalidArgument("ReshapeLiteral does not support tuples."); } xla::Shape shape = @@ -138,71 +81,27 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, return Status::OK(); } -template -static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { - Tensor linspace(DataTypeToEnum::v(), shape); - auto linspace_flat = linspace.flat(); - for (int64 i = 0; i < depth; ++i) { - linspace_flat(i) = i; - } - return linspace; -} - -xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, - int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/false); -} - -xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, - int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/true); -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, const xla::XlaOp& off_value, xla::XlaOp* one_hot) { - const int indices_dims = indices_shape.dims(); - const int output_dims = indices_dims + 1; - - TensorShape output_shape = indices_shape; - output_shape.InsertDim(axis, depth); - - // Build a Tensor populated with values 0, 1, 2, ... depth. 
- std::vector linspace_dims(output_dims, 1); - linspace_dims[axis] = depth; - TensorShape linspace_shape(linspace_dims); - Tensor linspace; - switch (index_type) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(index_type)); - } - - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = xla::Eq( - indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); + + TensorShape output_shape = indices_shape; + output_shape.InsertDim(axis, depth); + xla::Shape iota_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(index_type, output_shape, &iota_shape)); // Selects the user-provided off_value and on_value values. 
- *one_hot = xla::Select(one_hot_bool, - xla::Broadcast(on_value, output_shape.dim_sizes()), - xla::Broadcast(off_value, output_shape.dim_sizes())); + *one_hot = xla::Select( + xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), + xla::Broadcast(on_value, output_shape.dim_sizes()), + xla::Broadcast(off_value, output_shape.dim_sizes())); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 4858dfee55a393d04cd2af83916eeb40820ee368..490923526bd3acd4b167ccb3faff1d6c9e631131 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -53,16 +53,6 @@ class XlaHelpers { absl::Span shape, xla::Literal* output); - // Returns the argmax of `input` along `axis`. `output_type` is the type to - // use for the output. - static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, - int axis); - - // Returns the argmin of `input` along `axis`. `output_type` is the type to - // use for the output. - static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, - int axis); - // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. 
`on_value` and diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index fabbcd04fed96ad814d04c2df9394f43bfe0cf99..884dc45cb11b18ae557c3da3f4192b3805cb7980 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -135,24 +135,34 @@ XlaJitCompiledCpuFunction::Compile( jit->arg_index_table_ = std::move(arg_index_table); jit->program_shape_ = absl::make_unique(program_shape->ToProto()); - jit->static_data_.set_raw_function(raw_function); - jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); - jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); - jit->static_data_.set_arg_index_table(jit->arg_index_table_.data()); - jit->static_data_.set_num_args(jit->arg_index_table_.size()); - jit->static_data_.set_result_index(result_index); + XlaCompiledCpuFunction::set_static_data_raw_function(&jit->static_data_, + raw_function); + XlaCompiledCpuFunction::set_static_data_buffer_infos( + &jit->static_data_, jit->buffer_infos_.data()); + XlaCompiledCpuFunction::set_static_data_num_buffers( + &jit->static_data_, jit->buffer_infos_.size()); + XlaCompiledCpuFunction::set_static_data_arg_index_table( + &jit->static_data_, jit->arg_index_table_.data()); + XlaCompiledCpuFunction::set_static_data_num_args( + &jit->static_data_, jit->arg_index_table_.size()); + XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_, + result_index); // Optional metadata is collected and set below. 
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, &jit->result_names_); - jit->static_data_.set_arg_names(jit->arg_names_.data()); - jit->static_data_.set_result_names(jit->result_names_.data()); - jit->static_data_.set_program_shape(jit->program_shape_.get()); + XlaCompiledCpuFunction::set_static_data_arg_names(&jit->static_data_, + jit->arg_names_.data()); + XlaCompiledCpuFunction::set_static_data_result_names( + &jit->static_data_, jit->result_names_.data()); + XlaCompiledCpuFunction::set_static_data_program_shape( + &jit->static_data_, jit->program_shape_.get()); if (cpu_executable->hlo_profiling_enabled()) { - jit->static_data_.set_hlo_profile_printer_data( - &cpu_executable->hlo_profile_printer_data()); - jit->static_data_.set_profile_counters_size( + XlaCompiledCpuFunction::set_static_data_hlo_profile_printer_data( + &jit->static_data_, &cpu_executable->hlo_profile_printer_data()); + XlaCompiledCpuFunction::set_static_data_profile_counters_size( + &jit->static_data_, cpu_executable->hlo_profile_printer_data().profile_counters_size()); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 58808c76de6330a6b28e21dbdead03dea25847f6..ee11f3a3de658c7e5108605122b84fbc3e1cd963 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -93,7 +93,7 @@ TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { } DataType XlaOpKernelContext::input_type(int index) const { - return context_->input(index).dtype(); + return context_->input_dtype(index); } DataType XlaOpKernelContext::InputType(absl::string_view name) { @@ -178,7 +178,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( // Converts an int32 or int64 scalar literal to an int64. 
static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, int64* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 0) { + if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S32) { @@ -194,7 +194,7 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, // Converts an float32 or float64 scalar literal to a float64. static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, double* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 0) { + if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::F32) { @@ -228,8 +228,9 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { // Converts an int32 or int64 1D literal to an int64 vector. static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 1) { - return errors::InvalidArgument("value is not 1D"); + if (literal.shape().rank() != 1) { + return errors::InvalidArgument("value is not 1D, rank: ", + literal.shape().rank()); } int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); if (literal.shape().element_type() == xla::S32) { @@ -318,6 +319,27 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } +Status XlaOpKernelContext::ConstantInputAsPartialShape( + int index, PartialTensorShape* shape) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + // If `literal` is a scalar it's value must be -1. + if (literal.shape().rank() == 0) { + int64 shape_val; + TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val)); + if (shape_val != -1) { + return errors::InvalidArgument( + "Cannot convert value to PartialTensorShape: ", shape_val); + } + *shape = PartialTensorShape(); // Shape with unknown rank. 
+ return Status::OK(); + } + std::vector dims; + TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); + *shape = PartialTensorShape(dims); + return Status::OK(); +} + Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { @@ -353,8 +375,8 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); if (!variable->initialized()) { - return errors::InvalidArgument("Read of uninitialized variable ", - variable->name()); + return errors::FailedPrecondition("Read of uninitialized variable ", + variable->name()); } if (variable->type() != type) { return errors::InvalidArgument( @@ -446,6 +468,16 @@ void XlaOpKernelContext::SetOutputExpression(int index, } } +xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { SetOutputExpression( index, @@ -456,6 +488,11 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { SetOutputExpression(index, XlaExpression::Constant(constant)); } +void XlaOpKernelContext::SetTensorListOutput(int index, + const xla::XlaOp& handle) { + SetOutputExpression(index, XlaExpression::TensorList(handle)); +} + void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { SetOutputExpression(index, XlaExpression::Resource(resource)); } @@ -497,6 +534,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, handle = xla::Reshape(handle, xla::AsInt64Slice(representation_shape.dimensions())); } + variable->SetRepresentationShape(representation_shape); return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h 
b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 1858844bc05a6e12abbf07af83cad816590ddd03..cc2d5e8de3eb020ba41dfed7d730b48cd0534b4c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -138,6 +138,10 @@ class XlaOpKernelContext { // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); + // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 + // into a PartialTensorShape. + Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape); + // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. @@ -155,6 +159,11 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } + // Returns the type of output `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType output_xla_type(int index); + // Sets output `index` to the XlaOp `handle`. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. @@ -168,6 +177,9 @@ class XlaOpKernelContext { // Returns an XlaExpression describing the value of 'index'. void SetOutputExpression(int index, const XlaExpression& expression); + // Sets output `index` to the Tensor List `handle`. + void SetTensorListOutput(int index, const xla::XlaOp& handle); + // Status handling. 
void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 14237df69081016817fbd1a5332f22996e7f264d..26314034a18b2a77a3529f0c1af242e29ec69902 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -73,6 +73,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_resource_types settings."; return false; } + if (x.allow_variant_types != y.allow_variant_types) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible allow_variant_types settings."; + return false; + } if (!x.has_device_whitelist && !y.has_device_whitelist) { LOG(WARNING) << "Duplicate registrations of " << x.name << "with no device whitelists."; @@ -289,6 +294,9 @@ void XlaOpRegistry::RegisterCompilationKernels() { if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); } + if (op_registration->allow_variant_types) { + allowed_values->add_type(DT_VARIANT); + } // Don't build KernelDefs that have unsatisfiable type constraints. 
if (allowed_values->type().empty()) { unsatisfiable_type_constraint = true; @@ -485,6 +493,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowVariantTypes() { + registration_->allow_variant_types = true; + return *this; +} + XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( absl::string_view attr_name, DataType allowed) { std::set& types = diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 0bdd4a1085445420a5147756daac4a54f4725f11..c5e078a02d1ca6fdd8405ae6556a5205e387421e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,13 +47,14 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array kNumericTypes = { +constexpr std::array kNumericTypes = { {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { +constexpr std::array kCpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, - DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL}}; constexpr std::array kGpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, @@ -211,6 +212,10 @@ class XlaOpRegistry { // allow DT_RESOURCE. bool allow_resource_types = false; + // Should we allow variant types for type attributes? Used by While to + // allow TensorList which is of type DT_VARIANT. + bool allow_variant_types = false; + // Mapping from attribute name to a list of supported types. 
std::unordered_map> type_constraints; @@ -232,9 +237,9 @@ class XlaOpRegistry { // Returns true if registrations x and y can both be added to the registry. // This is always the case if they refer to different ops. If they refer to - // the same op name, they must: have the same values for compilation_only and - // allow_resource_types; use a device_whitelist; and their - // whitelists must not intersect. + // the same op name, they must: have the same values for compilation_only, + // allow_resource_types and allow_variant_types; use a device_whitelist; and + // their whitelists must not intersect. static bool IsCompatible(const OpRegistration& x, const OpRegistration& y); static Status CompileTimeConstantInputs(const NodeDef& node_def, @@ -292,6 +297,9 @@ class XlaOpRegistrationBuilder { // Allow DT_RESOURCE types for type parameters. XlaOpRegistrationBuilder& AllowResourceTypes(); + // Allow DT_VARIANT types for type parameters. + XlaOpRegistrationBuilder& AllowVariantTypes(); + // Mark 'input_name' as an argument whose value must be known at compile-time. XlaOpRegistrationBuilder& CompileTimeConstantInput( absl::string_view input_name); diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 736588bb8b89ba756cdce77eeebff8d1fcf4774c..ab3a5bdd9bc580c16d65d35c3be3ba8204511f83 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -86,6 +86,12 @@ class XlaResource { // variables have new values that need to be written back. const xla::XlaOp& initial_value() const { return initial_value_; } + // An xla shape that indicates how this resource variable is represented on + // device. + const absl::optional& representation_shape() const { + return representation_shape_; + } + // A variable is initialized if it has a value. 
bool initialized() const { return value_.valid(); } @@ -100,6 +106,11 @@ class XlaResource { // Sets the current value of the resource to an all-zero value. Status SetZeroValue(xla::XlaBuilder* builder); + // Sets the representational shape of the resource on device. + void SetRepresentationShape(const xla::Shape& shape) { + representation_shape_ = absl::make_optional(shape); + } + // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator @@ -160,6 +171,10 @@ class XlaResource { xla::XlaOp value_; xla::XlaOp initial_value_; + // An xla shape that indicates how this resource variable is represented on + // device. + absl::optional representation_shape_; + int64 max_array_size_ = -1; bool tensor_array_multiple_writes_aggregate_ = false; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4360e0857964b0ac63fc887e269b04a4b00d854a..ee6f7d5956ede4af99498ca0df5de47150cc5e4d 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -109,7 +109,7 @@ cc_library( name = "status_macros", srcs = ["status_macros.cc"], hdrs = ["status_macros.h"], - visibility = [":friends"], + visibility = ["//visibility:public"], deps = [ ":statusor", ":types", @@ -150,9 +150,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":status", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/stream_executor", + "//tensorflow/stream_executor/lib", ], ) @@ -194,7 +192,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", - "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", ], ) @@ -224,6 +222,7 @@ cc_library( name = "shape_util", srcs = [ "index_util.cc", + "layout.cc", "layout_util.cc", "primitive_util.cc", "shape.cc", @@ -231,6 +230,7 @@ cc_library( ], hdrs = [ "index_util.h", + "layout.h", "layout_util.h", 
"primitive_util.h", "shape.h", @@ -290,6 +290,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "primitive_util_test", + srcs = ["primitive_util_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "layout_util_test", srcs = ["layout_util_test.cc"], @@ -301,6 +317,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "layout_test", + srcs = ["layout_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "index_util_test", srcs = ["index_util_test.cc"], @@ -575,6 +607,7 @@ cc_library( ":types", ":util", ":xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", @@ -682,6 +715,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -705,8 +739,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -790,13 +824,13 @@ cc_library( "debug_options_parsers.h", ], hdrs = ["debug_options_flags.h"], + visibility = [":friends"], deps = [ ":parse_flags_from_env", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", ], ) diff --git 
a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 58cc1575858201b4508d7340cb47e59c4f4c5783..529e7f77cec43f3158fcb59a53efa9a085d7422a 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -272,6 +272,15 @@ class Array { std::iota(&values_[0], &values_[0] + num_elements(), value); } + // Fills the array with a repeating sequence: + // [value, value + 1, ..., value + length - 1, value, ... ] + void FillRepeatedIota(const T& value, int64 length) { + for (int64 i = 0; i < num_elements(); i += length) { + std::iota(&values_[i], &values_[std::min(i + length, num_elements())], + value); + } + } + // Fills the array with the sequence i*multiplier for i=0,1,... void FillWithMultiples(const T& multiplier) { for (int64 i = 0; i < num_elements(); ++i) { @@ -280,11 +289,11 @@ class Array { } // Fills the array with random normal variables with the specified mean. - void FillRandom(const T& value, const double mean = 0.0, + void FillRandom(const T& stddev, const double mean = 0.0, const int seed = 12345) { std::mt19937 g(seed); std::normal_distribution distribution(mean, - static_cast(value)); + static_cast(stddev)); for (int64 i = 0; i < num_elements(); ++i) { values_[i] = static_cast(distribution(g)); } diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index fe99564d3c671cd7890e1fa26fcd2e3384972983..f5d56e8a9e1f3a05e1039f7cc90194407200f1ab 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # Apache 2.0 -package(default_visibility = [":friends"]) +package(default_visibility = ["//visibility:public"]) package_group( name = "friends", @@ -170,6 +170,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", ], ) @@ -245,6 +246,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", 
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 74b76f929949d3300a5d0ff45d5fa4cd9f162642..4f020bcec2756a328755d86ab04154d54f532465 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -186,7 +186,7 @@ StatusOr Client::ComputeConstant(const XlaComputation& computation, ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; + *request.mutable_output_layout() = output_layout->ToProto(); } ComputeConstantResponse response; @@ -278,53 +278,51 @@ StatusOr> Client::Execute( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { - if (execution_options != nullptr && - execution_options->device_handles_size() > 1) { - std::vector computation_instances = { - XlaComputationInstance{ - computation, - std::vector(arguments.begin(), arguments.end()), - *execution_options, execution_profile}}; - TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); - // The result selection is a bit hacky, but better than assuming it is - // device 0. - // - // TODO(b/118493728): Allow Execute to return one result per computation. - for (int64 i = 0; i < results.size(); i++) { - TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); - if (!ShapeUtil::IsEmptyTuple(shape)) { - VLOG(3) << "Fetching result from device " << i << ": " - << ShapeUtil::HumanString(shape); - return std::move(results[i]); - } + // Create an ExecutionOptions if necessary, or set its DeviceHandles. 
+ absl::optional options_storage; + if (!execution_options || execution_options->device_handles().empty()) { + if (execution_options) { + options_storage.emplace(*execution_options); + } else { + options_storage.emplace(CreateDefaultExecutionOptions()); } - TF_RET_CHECK(!results.empty()); - VLOG(1) << "Defaulting to device 0 result"; - return std::move(results[0]); - } - - // The argument shapes affect how the computation is compiled. - std::vector arg_shapes(arguments.size()); - for (int i = 0; i < arguments.size(); i++) { - TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i])); - } - - TF_ASSIGN_OR_RETURN(auto handle, - Compile(computation, arg_shapes, execution_options)); - - TF_ASSIGN_OR_RETURN(auto result, - Execute(handle, arguments, execution_profile)); - - if (execution_profile != nullptr) { - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, *execution_profile)); - VLOG(1) << execution_stats; + execution_options = &*options_storage; + + TF_ASSIGN_OR_RETURN(auto device_handles, + GetDeviceHandles(/*device_count=*/1)); + TF_RET_CHECK(!device_handles.empty()); + *options_storage->add_device_handles() = std::move(device_handles[0]); + } + + std::vector computation_instances = { + XlaComputationInstance{ + computation, + std::vector(arguments.begin(), arguments.end()), + *execution_options, execution_profile}}; + + // Instead of invoking Compile() and Execute(), invoke + // Service::ExecuteParallel() to execute our one computation. Compile() + // caches the executable forever, which isn't what we want. + VLOG(1) << "Making ExecuteParallel request: " + << execution_options->DebugString(); + TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); + VLOG(1) << "ExecuteParallel request done."; + + // The result selection is a bit hacky, but better than assuming it is + // device 0. + // + // TODO(b/118493728): Allow Execute to return one result per computation. 
+ for (int64 i = 0; i < results.size(); i++) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); + if (!ShapeUtil::IsEmptyTuple(shape)) { + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return std::move(results[i]); } } - - return std::move(result); + TF_RET_CHECK(!results.empty()); + VLOG(1) << "Defaulting to device 0 result"; + return std::move(results[0]); } StatusOr>> Client::ExecuteParallel( diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index d0ac4703c632e0e01d3c8911594b46fedf28930d..eff8713ac340e82ee7633f1f078334ba73b67b2f 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -52,6 +52,12 @@ class Client { // need to live beyond this call.) // * If execution_options.device_handles should be empty. If you need // non-empty device handles, call 'Execute' instead. + // + // TODO(b/122731460): This call caches the resulting Executable in the Service + // *forever*. If you're only going to run the computation once, you may want + // to call the Execute(const XlaComputation&) overload. If you're going to + // run the computation more than once but you want control over when the + // Executable is unloaded, use the LocalClient API. StatusOr Compile( const XlaComputation& computation, absl::Span argument_shapes, @@ -76,6 +82,10 @@ class Client { // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. + // + // TODO(b/122731460): The given computation is compiled and then thrown away + // immediately after it's run. If you want control over how long the + // resulting Executable lives, use the LocalClient API. 
StatusOr> Execute( const XlaComputation& computation, absl::Span arguments, diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 27b7fa7b29206affa9f9c2e4becd9e4ea66484ab..42aae026229a49fd801cc90562fa51f604336148 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -24,12 +24,14 @@ limitations under the License. namespace xla { -LocalClientOptions::LocalClientOptions(se::Platform* platform, - int number_of_replicas, - int intra_op_parallelism_threads) +LocalClientOptions::LocalClientOptions( + se::Platform* platform, int number_of_replicas, + int intra_op_parallelism_threads, + const absl::optional>& allowed_devices) : platform_(platform), number_of_replicas_(number_of_replicas), - intra_op_parallelism_threads_(intra_op_parallelism_threads) {} + intra_op_parallelism_threads_(intra_op_parallelism_threads), + allowed_devices_(allowed_devices) {} LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) { platform_ = platform; @@ -58,6 +60,17 @@ int LocalClientOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +LocalClientOptions& LocalClientOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + return *this; +} + +const absl::optional>& LocalClientOptions::allowed_devices() + const { + return allowed_devices_; +} + /* static */ ClientLibrary& ClientLibrary::Singleton() { static ClientLibrary* c = new ClientLibrary; return *c; @@ -67,9 +80,10 @@ ClientLibrary::ClientLibrary() = default; ClientLibrary::~ClientLibrary() = default; /* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( - se::Platform* platform) { + se::Platform* platform, const absl::optional>& device_set) { LocalClientOptions default_options; default_options.set_platform(platform); + default_options.set_allowed_devices(device_set); return 
GetOrCreateLocalClient(default_options); } @@ -94,7 +108,7 @@ ClientLibrary::~ClientLibrary() = default; service_options.set_number_of_replicas(replica_count); service_options.set_intra_op_parallelism_threads( options.intra_op_parallelism_threads()); - + service_options.set_allowed_devices(options.allowed_devices()); auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 3ad558fa532931937fab898f7b855f0a3370eaec..62d225c6c298b26bbbd248fc1f4be64fc8efcf6b 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -23,9 +23,11 @@ limitations under the License. #include #include +#include #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" @@ -43,9 +45,10 @@ namespace xla { // Options to configure the local client when it is created. class LocalClientOptions { public: - LocalClientOptions(se::Platform* platform = nullptr, - int number_of_replicas = 1, - int intra_op_parallelism_threads = -1); + LocalClientOptions( + se::Platform* platform = nullptr, int number_of_replicas = 1, + int intra_op_parallelism_threads = -1, + const absl::optional>& allowed_devices = absl::nullopt); // Set the platform backing the service, or nullptr for the default platform. LocalClientOptions& set_platform(se::Platform* platform); @@ -60,10 +63,17 @@ class LocalClientOptions { LocalClientOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices set for selectively constructing stream executors + // on the platform. 
+ LocalClientOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_; int number_of_replicas_; int intra_op_parallelism_threads_; + absl::optional> allowed_devices_; }; class ClientLibrary { @@ -73,8 +83,11 @@ class ClientLibrary { // // platform : The platform the underlying XLA service should target. If // null then default platform is used. + // device_set: Set of device IDs for which the stream executor will be + // created, for the given platform. static StatusOr GetOrCreateLocalClient( - se::Platform* platform = nullptr); + se::Platform* platform = nullptr, + const absl::optional>& allowed_devices = absl::nullopt); static StatusOr GetOrCreateLocalClient( const LocalClientOptions& options); diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 1f594e551af381d7537e947892cbf7e0b5b3b861..ec0e08975926f36c36c854f83a40b374b12a09a4 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -58,6 +58,12 @@ const Shape* ExecutableBuildOptions::result_layout() const { return result_layout_set_ ? 
&result_layout_ : nullptr; } +ExecutableBuildOptions& ExecutableBuildOptions::set_num_replicas( + int num_replicas) { + num_replicas_ = num_replicas; + return *this; +} + string ExecutableBuildOptions::ToString() const { string result_layout = "nullopt"; if (result_layout_set_) { @@ -65,8 +71,9 @@ string ExecutableBuildOptions::ToString() const { } return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " - "generate_hlo_graph=%s}", - device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph()); + "generate_hlo_graph=%s, num_replicas=%d}", + device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph(), + num_replicas_); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index a58090253bfac7779e4b61bc7231a0f0d945cc00..1d85fb34304b95d1fccdb0b0d6a7a65e739fae18 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -67,12 +67,18 @@ class ExecutableBuildOptions { // debugging. string ToString() const; + // The number of replicas of this computation that are to be executed. + // Defaults to 1. + int num_replicas() const { return num_replicas_; } + ExecutableBuildOptions& set_num_replicas(int num_replicas); + private: int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; absl::optional debug_options_; DeviceMemoryAllocator* device_allocator_ = nullptr; + int num_replicas_ = 1; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 41db8de29ff0085a30847ff41db4ffbfc774e2a1..c5dea5f18030f2d226c86e3408ea85b2b5989728 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -1,5 +1,7 @@ # Common computation builders for XLA. 
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) @@ -13,9 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") - # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -35,6 +34,95 @@ cc_library( ], ) +xla_test( + name = "arithmetic_test", + srcs = ["arithmetic_test.cc"], + deps = [ + ":arithmetic", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "cholesky_test", + srcs = ["cholesky_test.cc"], + tags = ["optonly"], + deps = [ + ":arithmetic", + ":cholesky", + ":matrix", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + 
"//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "comparators", + srcs = ["comparators.cc"], + hdrs = ["comparators.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "comparators_test", + srcs = ["comparators_test.cc"], + deps = [ + ":comparators", + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:inlined_vector", + ], +) + cc_library( name = "constants", srcs = ["constants.cc"], @@ -52,7 +140,6 @@ cc_library( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":constants", "//tensorflow/compiler/xla:test", @@ -75,11 +162,28 @@ cc_library( ], ) +cc_library( + name = "loops", + srcs = ["loops.cc"], + hdrs = ["loops.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "math", srcs = ["math.cc"], hdrs = ["math.h"], deps = [ + ":arithmetic", ":constants", 
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -90,7 +194,23 @@ cc_library( xla_test( name = "math_test", srcs = ["math_test.cc"], - tags = ["enable_for_xla_interpreter"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +xla_test( + name = "math_exhaustive_test", + srcs = ["math_exhaustive_test.cc"], + shard_count = 16, deps = [ ":math", "//tensorflow/compiler/xla:literal_util", @@ -110,13 +230,18 @@ cc_library( deps = [ ":arithmetic", ":constants", + ":slicing", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -124,16 +249,19 @@ cc_library( xla_test( name = "matrix_test", srcs = ["matrix_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":matrix", ":slicing", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", ], ) @@ -172,23 +300,59 @@ 
cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", "@com_google_absl//absl/base", ], ) cc_library( - name = "slicing", - srcs = ["slicing.cc"], - hdrs = ["slicing.h"], + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], deps = [ + ":arithmetic", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "qr_test", + srcs = ["qr_test.cc"], + tags = ["optonly"], + deps = [ + ":matrix", + ":qr", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", "@com_google_absl//absl/types:span", ], ) @@ -196,13 +360,11 @@ cc_library( xla_test( name = "slicing_test", srcs = ["slicing_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":slicing", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", 
"//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -214,6 +376,7 @@ cc_library( srcs = ["sorting.cc"], hdrs = ["sorting.h"], deps = [ + ":comparators", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -225,13 +388,42 @@ cc_library( xla_test( name = "sorting_test", srcs = ["sorting_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":sorting", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "quantize", + hdrs = ["quantize.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "quantize_test", + srcs = ["quantize_test.cc"], + # TODO(b/122119490): re-enable TAP after fixing. 
+ tags = [ + "notap", + ], + deps = [ + ":quantize", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -260,46 +452,52 @@ cc_library( ) cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], + name = "self_adjoint_eig", + srcs = ["self_adjoint_eig.cc"], + hdrs = ["self_adjoint_eig.h"], deps = [ - "//tensorflow/compiler/xla:literal", + ":arithmetic", + ":comparators", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) xla_test( - name = "triangular_solve_test", - srcs = ["triangular_solve_test.cc"], - tags = ["noasan"], # sometimes times out, http://b/78650012 + name = "self_adjoint_eig_test", + srcs = ["self_adjoint_eig_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + real_hardware_only = True, + shard_count = 10, + tags = ["optonly"], deps = [ - ":triangular_solve", + ":arithmetic", + ":constants", + ":matrix", + ":self_adjoint_eig", "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - 
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index e86c10f030f3990d67e5a6638100640f73c82307..3b875135af29f142463ffd783bfeaadc61ada1af 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -117,10 +117,70 @@ XlaOp Any(XlaOp predicates) { XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::vector all_dimensions(predicates_shape.rank()); std::iota(all_dimensions.begin(), all_dimensions.end(), 0); return Reduce(predicates, f, logical_or, all_dimensions); }); } +namespace { + +XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + XlaOp init_value; + XlaComputation reducer; + if (is_min) { + init_value = MaxValue(builder, input_shape.element_type()); + reducer = CreateScalarMinComputation(input_shape.element_type(), builder); + } else { + init_value = MinValue(builder, input_shape.element_type()); + reducer = CreateScalarMaxComputation(input_shape.element_type(), builder); + } + + XlaOp input_max = Reduce(input, init_value, reducer, + /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(input_shape.rank() - 1); + 
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + XlaOp partial_mask = + ConvertElementType(Eq(input, input_max, broadcast_dims), output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; + XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type); + XlaOp full_mask = ShiftRightArithmetic( + ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + + const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis); + XlaOp iota = Iota(builder, output_type, axis_size); + XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. 
+ return Reduce(product, MinValue(builder, output_type), + CreateScalarMaxComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + }); +} + +} // namespace + +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/false); +} + +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/true); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 632e8cc8bc64fad236a0226c6e93079aadde7050..d4a7812c441c351b121e5d72faf9642b06728b18 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -57,6 +57,14 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type, // Note: if predicates is zero-sized, Any() vacuously returns false. XlaOp Any(XlaOp predicates); +// Returns the argmax of `input` along `axis`. `output_type` is the type to +// use for the output. +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); + +// Returns the argmin of `input` along `axis`. `output_type` is the type to +// use for the output. +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a13839f9db89b9c07f2465867a503ef2193f8160 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using ArithmeticTest = ClientLibraryTestBase; + +XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMin(x, S32, /*axis=*/0); + + std::vector expected = {0, 2, 2}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMin(x, S32, /*axis=*/1); + + std::vector expected = {0, 1, 2}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMax(x, S32, /*axis=*/0); + + std::vector expected = {2, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMax(x, S32, /*axis=*/1); + + std::vector 
expected = {1, 0, 0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc similarity index 57% rename from tensorflow/compiler/tf2xla/lib/cholesky.cc rename to tensorflow/compiler/xla/client/lib/cholesky.cc index 550ab5b05693b79e60e49577309328ac6846d3f9..bb41f9932d1cc62b62d37fea2c10fbfeaa0bd15e 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -13,25 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { @@ -50,70 +50,63 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return 
builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) +XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int n_dims = a_shape.rank(); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( /*pos=*/0, /*len=*/n_dims - 2); - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = xla::Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = xla::Zeros(body_builder, col_shape); - - std::vector mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto body_fn = + [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> StatusOr> { + std::vector row_shape_dims(major_dims.begin(), major_dims.end()); + std::vector col_shape_dims(major_dims.begin(), major_dims.end()); + row_shape_dims.push_back(1); + row_shape_dims.push_back(n); + auto mask_zeros_row = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), row_shape_dims)); + 
+ col_shape_dims.push_back(n); + col_shape_dims.push_back(1); + auto mask_zeros_col = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), col_shape_dims)); + auto mask_range_row = - xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + Iota(body_builder, ShapeUtil::MakeShape(S32, row_shape_dims), + /*iota_dimension=*/n_dims - 1); auto mask_range_col = - xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + Iota(body_builder, ShapeUtil::MakeShape(S32, col_shape_dims), + /*iota_dimension=*/n_dims - 2); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; // row = l[..., i, :i] // select the whole i-th row, then mask out all columns past i-1 - auto zero = xla::ConstantR0(body_builder, 0); + auto zero = ConstantR0(body_builder, 0); auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); - auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i); // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) - auto l_ii = - xla::Pow(a_ii - diag_dot, - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + auto l_ii = Sqrt(a_ii - diag_dot); // a[..., i+1:, i] // select the whole i-th column, then mask out all rows above i+1 auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); - auto a_ip1i = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i); // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / // l[..., i, i] @@ -122,8 +115,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // r.T) auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = - 
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot); body_l = DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); @@ -131,12 +123,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // column assign will wrap around and overwrite the diagonal assign. body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); - return std::vector{body_a, body_l}; + return std::vector{body_a, body_l}; }; TF_ASSIGN_OR_RETURN( auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder)); return cholesky_while[1]; }); @@ -144,34 +136,41 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); +XlaOp Cholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = a_shape.rank(); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); + return InvalidArgument( + "Argument to Cholesky must have rank >= 2; shape was %s", + a_shape.ToString()); + } + + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + if (n != ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "Argument to Cholesky must be batched square matrices; got shape %s", + ShapeUtil::HumanString(a_shape)); } - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - 
"Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + if (primitive_util::IsComplexType(a_shape.element_type())) { + return Unimplemented( + "Complex types are not implemented in Cholesky; got shape %s", + ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); + return InvalidArgument( + "block_size argument to Cholesky must be >= 1; got %d", block_size); } // Blocked left-looking Cholesky factorization. // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -194,12 +193,12 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, // l[i+k:, i:i+k] = // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); - auto update = TriangularSolve(factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*transpose_a=*/true, - /*conjugate_a=*/false, - /*block_size=*/block_size); + auto update = + TriangularSolve(factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); l = UpdateSliceInMinorDims(l, update, {i + k, i}); } } @@ -207,4 +206,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/xla/client/lib/cholesky.h similarity index 87% rename from tensorflow/compiler/tf2xla/lib/cholesky.h rename to tensorflow/compiler/xla/client/lib/cholesky.h index 9a561c34b92ee45059f2a05336e682838f8e36e2..0bae26837c0f14dd0cfab82cf426becc787ec11c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ 
b/tensorflow/compiler/xla/client/lib/cholesky.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the Cholesky decompositions of a batch of symmetric positive // definite matrices. @@ -34,6 +34,6 @@ xla::XlaOp Cholesky( xla::XlaOp a, int64 block_size = 256, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/client/lib/cholesky_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..095dd4fbf8b7c90047c4428b50c626c16e9c1e94 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/cholesky_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/cholesky.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using xla::int64; + +using CholeskyTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(CholeskyTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a, /*block_size=*/2); + + xla::Array2D expected({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, Simple2) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array2D expected( + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + 
{ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array3D expected({ + { + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}, + }); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +using CholeskyTestCase = std::tuple; + +class RandomCholeskyTest + : public xla::ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(RandomCholeskyTest, Random) { + xla::XlaBuilder builder(TestName()); + + auto test_params = GetParam(); + std::vector dimensions = {std::get<0>(test_params), + std::get<1>(test_params), + std::get<1>(test_params)}; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions); + TF_ASSERT_OK_AND_ASSIGN( + auto literal, + xla::LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + + auto input = xla::Parameter(&builder, 0, shape, "input"); + // Form a random positive definite matrix. 
+ auto matrix = xla::BatchDot(input, TransposeInMinorDims(input), + xla::PrecisionConfig::HIGHEST); + + auto cholesky = xla::Cholesky(matrix, /*block_size=*/4); + + // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 + auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky), + xla::PrecisionConfig::HIGHEST); + auto delta = matrix - verification; + xla::Reduce(delta * delta, xla::ConstantR0(&builder, 0.0), + CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2}); + + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); + ComputeAndCompareR0(&builder, 0.0, {input_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest, + ::testing::Values(CholeskyTestCase{1, 1}, + CholeskyTestCase{1, 2}, + CholeskyTestCase{10, 5}, + CholeskyTestCase{2, 20})); + +} // namespace diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc new file mode 100644 index 0000000000000000000000000000000000000000..c620c9841a5146618e3a142adeb3fe2da525950a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -0,0 +1,159 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/comparators.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using XlaOpGenerator = XlaOp (*)(const XlaOp&, const XlaOp&, + absl::Span); + +XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, + int64 bit_width) { + PrimitiveType signed_type; + PrimitiveType unsigned_type; + XlaOp max_value; + switch (bit_width) { + case 16: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S16; + unsigned_type = U16; + break; + case 32: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S32; + unsigned_type = U32; + break; + case 64: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S64; + unsigned_type = U64; + break; + default: + return value.builder()->ReportError( + InvalidArgument("Invalid bit width %lld for Comparator floating " + "point parameter.", + bit_width)); + } + // Switch from a floating point value to a integer value in such a way that + // when using the integer value to compare, we get the same result for normal + // values, and -Nan is treated as the smallest value, and Nan is treated as + // the largest value. + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 
numeric_limits::max() - x : x; + // then y is ordered as an int32 such that finite values have the obvious + // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning + // and end of the ordering. + // Note that in order to avoid -x to overflow, we calculate + // numeric_limits::max() - x as unsigned, and then convert back to + // signed. + auto signed_value = BitcastConvertType(value, signed_type); + auto unsigned_value = BitcastConvertType(value, unsigned_type); + auto flipped_value = + BitcastConvertType(Sub(max_value, unsigned_value), signed_type); + auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type)); + return Select(is_negative, flipped_value, signed_value); +} + +XlaComputation CreateScalarComparisonComputation( + const string& name, const std::vector& operand_types, + XlaBuilder* builder, XlaOpGenerator generator) { + // Create a default computation where we compare only the first two + // parameters of type 'operand_types[0]'. + auto b = builder->CreateSubBuilder(name); + if (operand_types.empty()) { + b->ReportError(InvalidArgument("operand_types should not be empty")); + return b->BuildAndNoteError(); + } + + int64 parameter_count = 0; + XlaOp first_lhs_param; + XlaOp first_rhs_param; + + // For each type in 'operand_types' we create two parameters of this type. The + // idea is that this computation can be used by n-ary Sort, and potentially + // should support comparing also the other operands of sort. In this default + // computation, however, we will not actually use any parameters except the + // first two. 
+ for (auto operand_type : operand_types) { + auto scalar_shape = ShapeUtil::MakeShape(operand_type, {}); + auto lhs_param = Parameter(b.get(), parameter_count * 2, scalar_shape, + absl::StrCat("p.", parameter_count, ".lhs")); + auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape, + absl::StrCat("p.", parameter_count, ".rhs")); + if (parameter_count == 0) { + first_lhs_param = lhs_param; + first_rhs_param = rhs_param; + } + ++parameter_count; + } + if (primitive_util::IsFloatingPointType(operand_types[0])) { + PrimitiveType compare_type = operand_types[0]; + // Special-case handling for BF16. We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. + if (compare_type == BF16) { + compare_type = F32; + first_lhs_param = ConvertElementType(first_lhs_param, F32); + first_rhs_param = ConvertElementType(first_rhs_param, F32); + } + int64 bit_width = primitive_util::BitWidth(compare_type); + first_lhs_param = + BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width); + first_rhs_param = + BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width); + } + generator(first_lhs_param, first_rhs_param, {}); + return b->BuildAndNoteError(); +} +} // namespace + +// Creates a scalar less-than computation and returns it. +XlaComputation CreateScalarLtComputation( + const std::vector& operand_types, XlaBuilder* builder) { + return CreateScalarComparisonComputation("compare-less-than", operand_types, + builder, Lt); +} + +// Creates a scalar greater-than computation and returns it. 
+XlaComputation CreateScalarGtComputation( + const std::vector& operand_types, XlaBuilder* builder) { + return CreateScalarComparisonComputation("compare-greater-than", + operand_types, builder, Gt); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/comparators.h b/tensorflow/compiler/xla/client/lib/comparators.h new file mode 100644 index 0000000000000000000000000000000000000000..cbcfc227dd495537f59bf0a9090bad8ade15da62 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators.h @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Creates a scalar less-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// -NaN < -infinity < ... < -0 < 0 < ... 
< infinity < NaN +XlaComputation CreateScalarLtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +// Creates a scalar greater-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN +XlaComputation CreateScalarGtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ diff --git a/tensorflow/compiler/xla/client/lib/comparators_test.cc b/tensorflow/compiler/xla/client/lib/comparators_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..598956803b34702b1e095a342648d348fa350b29 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/comparators.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ComparatorsTest : public ClientLibraryTestBase { + public: + ComparatorsTest() : builder_(TestName()) {} + XlaBuilder* builder() { return &builder_; } + + private: + XlaBuilder builder_; +}; + +template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> +void BuildComparatorAndComparisons(ComparatorsTest* test, + bool compare_less_than, + absl::InlinedVector* expected) { + auto compare = compare_less_than + ? CreateScalarLtComputation({type}, test->builder()) + : CreateScalarGtComputation({type}, test->builder()); + + auto negative_nan = ConstantR0( + test->builder(), -T(std::numeric_limits::quiet_NaN())); + auto positive_nan = ConstantR0(test->builder(), + T(std::numeric_limits::quiet_NaN())); + auto negative_zero = ConstantR0(test->builder(), T(-0.)); + auto positive_zero = ConstantR0(test->builder(), T(0.)); + auto negative_infinity = MinValue(test->builder(), type); + auto positive_infinity = MaxValue(test->builder(), type); + + // List the values in the expected sorting order from smallest to largest. + std::vector all_constants{negative_nan, negative_infinity, + negative_zero, positive_zero, + positive_infinity, positive_nan}; + + // Do pairwise comparisons. 
+ std::vector all_comparisons; + for (const XlaOp& lhs_constant : all_constants) { + for (const XlaOp& rhs_constant : all_constants) { + all_comparisons.push_back(Broadcast( + Call(test->builder(), compare, {lhs_constant, rhs_constant}), {1})); + } + } + + // Concatenate the comparison results. + ConcatInDim(test->builder(), all_comparisons, 0); + + // If we use less-than comparisons, we expect the comparison to result in true + // if the lhs value to be compared appears earlier in 'all_constants' than the + // rhs value. Likewise, if we use greater-than comparisons, we expect the + // comparison to return true if the rhs value appears earlier in + // 'all_constants' than the lhs value. + expected->clear(); + for (int i = 0; i < all_constants.size(); ++i) { + for (int j = 0; j < all_constants.size(); ++j) { + expected->push_back(compare_less_than ? i < j : i > j); + } + } +} + +XLA_TEST_F(ComparatorsTest, CompareLtBF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtBF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF32) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF32) { + 
absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF64) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF64) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 1ada7b4a964ccf7ca400b937abbe425bef083468..6bd56a8df0a5d0417f747a158664ed0daa8a7b40 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -80,6 +80,24 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { } } +XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + std::numeric_limits::min()); + case BF16: + return ConstantR0(builder, bfloat16::min_positive_normal()); + case F32: + return ConstantR0(builder, std::numeric_limits::min()); + case F64: + return ConstantR0(builder, std::numeric_limits::min()); + default: + return builder->ReportError( + InvalidArgument("Invalid type for MinPositiveNormalValue (%s).", + PrimitiveType_Name(type))); + } +} + XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { return ConstantLiteral(builder, LiteralUtil::MaxValue(type)); } @@ -100,4 +118,28 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { } } +XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + switch (type) { + case F16: + return ConstantR0( + builder, Eigen::NumTraits::quiet_NaN()); + case BF16: + return 
ConstantR0( + builder, bfloat16(std::numeric_limits::quiet_NaN())); + case F32: + return ConstantR0(builder, + std::numeric_limits::quiet_NaN()); + case F64: + return ConstantR0(builder, + std::numeric_limits::quiet_NaN()); + default: + return InvalidArgument( + "Operand to NanValue was %s, but must be a real-valued " + "floating-point type.", + PrimitiveType_Name(type)); + } + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 81624614c1e3599dfe116eb61d9e2edcd5230684..47b8f1b44ffa12b2b15be0e865d693a709962e6e 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -56,6 +56,8 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { return ConstantR0(builder, static_cast(value)); case C64: return ConstantR0(builder, static_cast(value)); + case C128: + return ConstantR0(builder, static_cast(value)); case U8: return ConstantR0(builder, static_cast(value)); case U32: @@ -88,6 +90,27 @@ XlaOp ScalarLike(XlaOp prototype, T value) { }); } +// Returns an array or scalar containing copies of `value` cast to the same +// run-time type as `prototype` and broadcast to the same dimensions as +// `prototype`. +// +// If `prototype` is not a scalar or array, returns an error. +template +XlaOp FullLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { + return Broadcast(ScalarLike(prototype, value), shape.dimensions()); + } else { + return InvalidArgument( + "Prototype shape for BroadcastConstantLike must be a scalar or " + "array, but was %s", + shape.ToString()); + } + }); +} + // Returns a scalar with value '0' of 'type'. 
XlaOp Zero(XlaBuilder* builder, PrimitiveType type); @@ -112,6 +135,9 @@ XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); // point type, this is equal to -MaxFiniteValue(). XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); +// Returns the minimum positive normal value for floating-point type `type`. +XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type); + // Returns the maximum representable finite or infinite value for 'type'. // Returns 'inf' for floating-point types. XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); @@ -119,6 +145,9 @@ XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); // Returns the maximum representable finite value for 'type'. XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); +// Returns a nan for the given type. Only valid for real-valued fp types. +XlaOp NanValue(XlaBuilder* builder, PrimitiveType type); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc index f4320f65c1f76d4d4c384110b39d6606773aaf01..180175b7495b32250af8ae77c8c7fba804703885 100644 --- a/tensorflow/compiler/xla/client/lib/constants_test.cc +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -155,5 +155,12 @@ XLA_TEST_F(ConstantsTest, MaxValueF32) { {}); } +XLA_TEST_F(ConstantsTest, NanValueF32) { + XlaBuilder builder(TestName()); + NanValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::quiet_NaN(), + {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/xla/client/lib/loops.cc similarity index 50% rename from tensorflow/compiler/tf2xla/lib/while_loop.cc rename to tensorflow/compiler/xla/client/lib/loops.cc index 594ab1dfd0700f47501712183f6efe62d17e15e7..721f987628a8ac7da3f3f872939c3f0457d6bbe2 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ 
b/tensorflow/compiler/xla/client/lib/loops.cc @@ -13,44 +13,43 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -namespace tensorflow { +namespace xla { -xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { int arity = initial_values.size(); - std::vector var_shapes; + std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::XlaOp& input : initial_values) { + for (const XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); var_shapes.push_back(std::move(shape)); } - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::XlaOp tuple, int arity, - xla::XlaBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = xla::GetTupleElement(tuple, i); + elements[i] = GetTupleElement(tuple, i); } return elements; }; // Build the condition. 
- std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { - auto parameter = - xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -60,11 +59,10 @@ xla::StatusOr> XlaWhileLoop( TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(absl::StrCat(name, "_body")); { - auto parameter = - xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -72,56 +70,54 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - xla::Tuple(body_builder.get(), result); + Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); + auto outputs = While(cond, body, Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { - auto while_cond_fn = - [&](absl::Span values, - xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, - num_iterations)); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + return Lt(values[0], 
ConstantR0WithType(cond_builder, num_iterations_type, + num_iterations)); }; - auto while_body_fn = [&](absl::Span values, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::XlaOp iteration = values[0]; + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(xla::Add( + updated_values.push_back(Add( iteration, - xla::ConstantLiteral(body_builder, - xla::LiteralUtil::One(num_iterations_type)))); + ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); - values.push_back(xla::ConstantLiteral( - builder, xla::LiteralUtil::Zero(num_iterations_type))); + values.push_back( + ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); - TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, - name, builder)); + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + values, name, builder)); values.erase(values.begin(), values.begin() + 1); return values; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/xla/client/lib/loops.h similarity index 62% rename from tensorflow/compiler/tf2xla/lib/while_loop.h rename to tensorflow/compiler/xla/client/lib/loops.h index f2134bb4495a12b8342961d96f70e7737f816c7d..e11de59493e9c1de51fbdb6c45dab6d82b85a62a 100644 --- 
a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/xla/client/lib/loops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ #include #include @@ -25,19 +25,18 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -namespace tensorflow { +namespace xla { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, - xla::XlaBuilder*)> - LoopConditionFunction; +typedef std::function(absl::Span, XlaBuilder*)> + WhileLoopHelperConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - absl::Span, xla::XlaBuilder*)> - LoopBodyFunction; +typedef std::function>(absl::Span, + XlaBuilder*)> + WhileLoopHelperBodyFunction; // Helper function for building an XLA while loop, where the values carried by // the loop are a tuple of values, e.g., (a, b, c): @@ -47,27 +46,27 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. 
-xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::XlaOp, absl::Span, xla::XlaBuilder*)> +typedef std::function>( + XlaOp, absl::Span, XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 36fdda39b4124b9100c6054160f9c17bdf787d6f..f3fe3d0b5ebaabdc762c811027b85444db7b0d56 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -13,59 +13,103 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +// This macro is required to make MSVC define math constants in math.h +#define _USE_MATH_DEFINES +#include + #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" namespace xla { -XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); } +// TODO(jlebar): Use this function in more places in this file to restrict the +// domain of other functions. +static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) { + auto& b = *operand.builder(); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + auto elem_ty = shape.element_type(); + if (!primitive_util::IsFloatingPointType(elem_ty)) { + return InvalidArgument( + "Operands to %s must be real-valued floating-point, but got %s", + op_name, PrimitiveType_Name(elem_ty)); + } + return Status::OK(); +} -XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); } +XlaOp IsPosInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + // Note that this is only correct for floating-point types. If we wanted it + // to be correct for all types, we'd need to Gt(MaxFiniteValue). 
+ return Eq(operand, MaxValue(&b, shape.element_type())); + }); +} -XlaOp Square(XlaOp operand) { return operand * operand; } +XlaOp IsNegInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + // Note that this is only correct for floating-point types. If we wanted it + // to be correct for all types, we'd need to Lt(MinFiniteValue). + return Eq(operand, MinValue(&b, shape.element_type())); + }); +} -XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } +XlaOp IsInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand)); + return IsPosInf(Abs(operand)); + }); +} -namespace { +XlaOp IsNan(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand)); + return Ne(operand, operand); + }); +} -// Polynomials for computing erf/erfc. Originally from cephes. -// Note we use float for compatibility across devices, at the cost of some -// precision for 64 bit computations. -// -// Coefficients are in descending order. 
-std::array kErfcPCoefficient = { - 2.46196981473530512524E-10, 5.64189564831068821977E-1, - 7.46321056442269912687E0, 4.86371970985681366614E1, - 1.96520832956077098242E2, 5.26445194995477358631E2, - 9.34528527171957607540E2, 1.02755188689515710272E3, - 5.57535335369399327526E2}; -std::array kErfcQCoefficient = { - 1.00000000000000000000E0, 1.32281951154744992508E1, - 8.67072140885989742329E1, 3.54937778887819891062E2, - 9.75708501743205489753E2, 1.82390916687909736289E3, - 2.24633760818710981792E3, 1.65666309194161350182E3, - 5.57535340817727675546E2}; -std::array kErfcRCoefficient = { - 5.64189583547755073984E-1, 1.27536670759978104416E0, - 5.01905042251180477414E0, 6.16021097993053585195E0, - 7.40974269950448939160E0, 2.97886665372100240670E0}; -std::array kErfcSCoefficient = { - 1.00000000000000000000E0, 2.26052863220117276590E0, - 9.39603524938001434673E0, 1.20489539808096656605E1, - 1.70814450747565897222E1, 9.60896809063285878198E0, - 3.36907645100081516050E0}; -std::array kErfTCoefficient = { - 9.60497373987051638749E0, 9.00260197203842689217E1, - 2.23200534594684319226E3, 7.00332514112805075473E3, - 5.55923013010394962768E4}; -std::array kErfUCoefficient = { - 1.00000000000000000000E0, 3.35617141647503099647E1, - 5.21357949780152679795E2, 4.59432382970980127987E3, - 2.26290000613890934246E4, 4.92673942608635921086E4}; -} // namespace +XlaOp IsNegZero(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + + // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0 + // (sign bit on, all other bits off). 
+ switch (shape.element_type()) { + case F64: + return Eq(BitcastConvertType(operand, U64), + ConstantR0WithType(&b, U64, uint64{1} << 63)); + case F32: + return Eq(BitcastConvertType(operand, U32), + ConstantR0WithType(&b, U32, uint32{1} << 31)); + case F16: + case BF16: + // Not all XLA backends handle U16 well, so we convert to F32/U32. + // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for + // backends that *do* support it. + return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32), + ConstantR0WithType(&b, U32, uint32{1} << 31)); + default: + LOG(FATAL) << "Expected real fp type."; + } + }); +} + +XlaOp Square(XlaOp operand) { return operand * operand; } + +XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. @@ -77,27 +121,86 @@ XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { return poly; } -// Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x) { +// Computes an approximation of the error function complement (1 - erf(x)). +// +// Precondition: abs(x) >= 1. Otherwise, use ErfImpl. +// +// This follows Cephes's f32 implementation of erfc, and so it may have errors +// for double precision. +// +// See also these alternate implementations of erf and erfc: +// +// https://stackoverflow.com/questions/35148198 +// https://stackoverflow.com/questions/35966695 +// +static XlaOp ErfcImpl(XlaOp x) { + // Coefficients for erfc(f32), from Cephes. 
+ // + // erfc(x) = exp(-x^2) P(1/x), 1 < x < 2 + static std::array kErfcPCoefficient{ + +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1, + -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1, + +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1, + }; + // erfc(x) = exp(-x^2) 1/x P(1/x^2), 2 < x < 14 + static std::array kErfcRCoefficient{ + -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0, + +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1, + -2.820767439740514E-1, +5.641895067754075E-1, + }; + XlaOp abs_x = Abs(x); XlaOp z = Exp(-x * x); + XlaOp q = ScalarLike(x, 1) / abs_x; + XlaOp y = q * q; + XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)), + EvaluatePolynomial(y, kErfcPCoefficient), + EvaluatePolynomial(y, kErfcRCoefficient)); + y = z * q * p; + return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y, y); +} - XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); - XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); - XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); - XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); - - XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); +// Compute a polynomial approximation of the error function. +// +// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. +// +// This follows Cephes's f32 implementation of erf, so it may have errors for +// double precision. +static XlaOp ErfImpl(XlaOp x) { + // Coefficients for by erf(f32), from Cephes. 
+ // + // erf(x) = x P(x^2), 0 < x < 1 + static std::array kErfTCoefficient{ + +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3, + -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1, + +1.128379165726710E+0, + }; + + return x * EvaluatePolynomial(x * x, kErfTCoefficient); +} - return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); +XlaOp Erfc(XlaOp x) { + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x)); + // erfc(x) = + // erfc_impl(x) if x > 1 + // 1 - erf_impl(x) otherwise + return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl(x), + ScalarLike(x, 1) - ErfImpl(x)); + }); } -// Compute a polynomial approximation of the error function. XlaOp Erf(XlaOp x) { - XlaOp z = x * x; - XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); - XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); - return x * pt / pu; + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x)); + // erf(x) = + // erf_impl(x) if x < 1 + // 1 - erfc_impl(x) otherwise + return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl(x), + ScalarLike(x, 1) - ErfcImpl(x)); + }); } // Approximation for the inverse error function from @@ -113,37 +216,30 @@ XlaOp Erf(XlaOp x) { // } // return p*x XlaOp ErfInv(XlaOp x) { - XlaBuilder* b = x.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; + constexpr int kDegree = 9; + 
constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; - auto one = ScalarLike(x, 1.0); - auto w = -Log((one - x) * (one + x)); - - auto lt = Lt(w, ScalarLike(x, 5.0)); - auto coefficient = [&](int i) { - return Select(lt, - Broadcast(ScalarLike(x, w_less_than_5_constants[i]), - AsInt64Slice(shape.dimensions())), - Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), - AsInt64Slice(shape.dimensions()))); - }; - w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = coefficient(i) + p * w; - } - return p * x; - }); + auto one = ScalarLike(x, 1.0); + auto w = -Log((one - x) * (one + x)); + + auto lt = Lt(w, ScalarLike(x, 5.0)); + auto coefficient = [&](int i) { + return Select(lt, FullLike(x, w_less_than_5_constants[i]), + FullLike(x, w_greater_than_5_constants[i])); + }; + w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = coefficient(i) + p * w; + } + return p * x; } namespace { @@ -170,49 +266,86 @@ static constexpr std::array kLanczosCoefficients = { // t(z) = z + kLanczosGamma + 1/2 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) XlaOp Lgamma(XlaOp input) { - XlaOp one_half = ScalarLike(input, 0.5); - XlaOp one = ScalarLike(input, 1); - - XlaOp pi = ScalarLike(input, M_PI); - XlaOp log_pi = ScalarLike(input, std::log(M_PI)); - XlaOp log_sqrt_two_pi = ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); - - XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); - XlaOp 
log_lanczos_gamma_plus_one_half = - ScalarLike(input, std::log(kLanczosGamma + 0.5)); - - XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); - - // If the input is less than 0.5 use Gauss's reflection formula: - // gamma(x) = pi / sin(pi * x) * gamma(1 - x) - XlaOp need_to_reflect = Lt(Real(input), one_half); - XlaOp z = Select(need_to_reflect, -input, input - one); - - XlaOp x = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { - XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); - XlaOp index = ScalarLike(input, i); - x = x + lanczos_coefficient / (z + index + one); - } + auto& b = *input.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input)); + + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + XlaOp log_pi = ScalarLike(input, std::log(M_PI)); + XlaOp log_sqrt_two_pi = + ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); + + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Euler's reflection formula: + // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + XlaOp need_to_reflect = Lt(input, one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp x = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + x = x + lanczos_coefficient / (z + index + one); + } - // To improve accuracy on platforms with less-precise log implementations, - // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on - // the device. 
- // log(t) = log(kLanczosGamma + 0.5 + z) - // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) - XlaOp t = lanczos_gamma_plus_one_half + z; - XlaOp log_t = - log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); - - XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); - - // If z = a + 0j, the analytic continuation of log reduces to taking the - // absolute value of the real part. - // Re(log(z)) = Re(log|z| + arg(z)j) - // = log|a| - XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y; - XlaOp result = Select(need_to_reflect, reflection, log_y); - return result; + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = log_lanczos_gamma_plus_one_half + + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); + + // Compute the reflected value, used when x < 0.5: + // + // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). + // + // (The abs is because lgamma is the log of the absolute value of the gamma + // function.) + // + // We have to be careful when computing the final term above. gamma(x) goes + // to +/-inf at every integer x < 0, and this is controlled by the + // sin(pi * x) term. The slope is large, so precision is particularly + // important. + // + // Because abs(sin(pi * x)) has period 1, we can equivalently use + // abs(sin(pi * frac(x))) = sin(pi * frac(x)), where frac(x) is the + // fractional part of x. This is more numerically accurate: It doesn't + // overflow to inf like pi * x can, and if x is an integer, it evaluates to + // 0 exactly, which is significant because we then take the log of this + // value, and log(0) is inf. 
+ // + // We don't have a frac(x) primitive in XLA and computing it is tricky, but + // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for + // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). + // + XlaOp abs_input = Abs(input); + XlaOp reflection_denom = Log(Sin(pi * (abs_input - Floor(abs_input)))); + + // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, + // then it "wins" and the result is +/-inf. + XlaOp reflection = + Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y, + -reflection_denom); + XlaOp result = Select(need_to_reflect, reflection, log_y); + + // lgamma(+/-inf) = +inf. + XlaOp inf_bcast = FullLike(input, std::numeric_limits::infinity()); + return Select(Or(IsFinite(input), // is finite, or + Not(Or(Lt(input, one), Ge(input, one)))), // is nan + result, inf_bcast); + }); } // Compute the Digamma function using Lanczos' approximation from "A Precision @@ -223,69 +356,84 @@ XlaOp Lgamma(XlaOp input) { // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) // A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) XlaOp Digamma(XlaOp input) { - XlaOp zero = ScalarLike(input, 0); - XlaOp one_half = ScalarLike(input, 0.5); - XlaOp one = ScalarLike(input, 1); - - XlaOp pi = ScalarLike(input, M_PI); - - XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); - XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); - XlaOp log_lanczos_gamma_plus_one_half = - ScalarLike(input, std::log(kLanczosGamma + 0.5)); - - XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); - - // If the input is less than 0.5 use Gauss's reflection formula: - // digamma(x) = digamma(1 - x) - pi * cot(pi * x) - XlaOp need_to_reflect = Lt(Real(input), one_half); - XlaOp z = Select(need_to_reflect, -input, input - one); - - XlaOp num = zero; - XlaOp denom = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { - XlaOp 
lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); - XlaOp index = ScalarLike(input, i); - num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); - denom = denom + lanczos_coefficient / (z + index + one); - } + auto& b = *input.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input)); + + XlaOp zero = ScalarLike(input, 0); + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + + XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Euler's reflection formula: + // digamma(x) = digamma(1 - x) - pi * cot(pi * x) + XlaOp need_to_reflect = Lt(input, one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp num = zero; + XlaOp denom = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); + denom = denom + lanczos_coefficient / (z + index + one); + } - // To improve accuracy on platforms with less-precise log implementations, - // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on - // the device. 
- // log(t) = log(kLanczosGamma + 0.5 + z) - // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) - XlaOp t = lanczos_gamma_plus_one_half + z; - XlaOp log_t = - log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); - - XlaOp y = log_t + num / denom - lanczos_gamma / t; - XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); - XlaOp result = Select(need_to_reflect, reflection, y); - return result; + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = log_lanczos_gamma_plus_one_half + + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp y = log_t + num / denom - lanczos_gamma / t; + XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); + return Select(need_to_reflect, reflection, y); + }); } // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. XlaOp RoundToEven(XlaOp x) { - auto half = ScalarLike(x, 0.5); - auto one = ScalarLike(x, 1.0); - auto two = ScalarLike(x, 2.0); - - auto round_val = Floor(x); - auto fraction = x - round_val; - auto nearest_even_int = round_val - two * Floor(half * x); - auto is_odd = Eq(nearest_even_int, one); - return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), - round_val + one, round_val); + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + // Reject non-real non-fp inputs (What does it even mean to round a complex + // number? Do you round each component equally? In that case, you should + // just ask for that explicitly.) 
+ TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x)); + + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); + + auto round_val = Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); + }); } // Trigonometric functions. -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// pi if x == -1 XlaOp Acos(XlaOp x) { - return ScalarLike(x, 2.0) * - Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), ScalarLike(x, 1.0) + x); + return Select(Ne(x, FullLike(x, -1)), + ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), + ScalarLike(x, 1.0) + x), + FullLike(x, M_PI)); } // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) @@ -323,9 +471,88 @@ XlaOp MaybeConjugate(XlaOp x, bool conjugate) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == C64 && conjugate; + auto perform_conj = + primitive_util::IsComplexType(shape.element_type()) && conjugate; return perform_conj ? Conj(x) : x; }); } +XlaOp NextAfter(XlaOp from, XlaOp to) { + auto builder = from.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from)); + int bitwidth = primitive_util::BitWidth(shape.element_type()); + auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth); + auto from_as_int = BitcastConvertType(from, int_type); + auto to_as_int = BitcastConvertType(to, int_type); + + // The result is NaN if either "from" or "to" are NaN. 
+ auto from_is_nan = Ne(from, from); + auto to_is_nan = Ne(to, to); + auto nan_input = Or(from_is_nan, to_is_nan); + auto result_for_nan = + Broadcast(ScalarLike(from, std::numeric_limits::quiet_NaN()), + shape.dimensions()); + result_for_nan = BitcastConvertType(result_for_nan, int_type); + + // The sign bit is the MSB. + const int64 sign_mask = int64{1} << (bitwidth - 1); + // Discard the sign bit to make the result non-negative. + auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask)); + auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask)); + + // When both "from" and "to" are equal, the result is "to". + // N.B. It would not make a difference if we chose the result to be "from". + auto from_and_to_are_equal = Eq(from_as_int, to_as_int); + auto result_for_equal = to_as_int; + + // When both "from" and "to" are both 0, the result is "to". This ensures we + // get a zero signed like "to". + auto from_is_zero = Eq(from_abs, ZerosLike(from_abs)); + auto to_is_zero = Eq(to_abs, ZerosLike(to_abs)); + auto result_for_both_zero = to_as_int; + + auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask)); + auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask)); + + // If from == 0 && to != 0, we need to return the smallest subnormal number + // signed like "to". + auto result_for_from_zero_to_non_zero = + Or(to_sign, ScalarLike(from_as_int, 1)); + + // If the sign of "from" and "to" disagree: + // - we need to make the magnitude of "from" smaller so that it is closer to + // zero. + // + // Otherwise the signs agree: + // - "from" with a magnitude larger than "to" means we need to make the + // magnitude smaller. + // - "from" with a magnitude smaller than "to" means we need to make the + // magnitude larger. + // - "from" with the same magnitude and sign as "to" has already been + // handled. 
+ auto signs_disagree = Ne(from_sign, to_sign); + auto from_magnitude_larger_than_to = Gt(from_abs, to_abs); + auto result_has_smaller_magnitude = + Or(from_magnitude_larger_than_to, signs_disagree); + auto magnitude_adjustment = + Select(result_has_smaller_magnitude, + Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()), + Broadcast(ScalarLike(from_as_int, 1), shape.dimensions())); + auto result = Add(from_as_int, magnitude_adjustment); + // Handle from == ±0. + result = Select(from_is_zero, + Select(to_is_zero, result_for_both_zero, + result_for_from_zero_to_non_zero), + result); + // Handle from == to. + result = Select(from_and_to_are_equal, result_for_equal, result); + // Handle isnan(from) || isnan(to). + result = Select(nan_input, result_for_nan, result); + + // Cast back to the original type. + return BitcastConvertType(result, shape.element_type()); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 17612bf9fdc0f1eabb338671c93c025c5b268872..71a3acedcec0a8e65561d4139baeaf532ec8bf46 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -20,11 +20,22 @@ limitations under the License. namespace xla { -// Computes the square root of 'operand'. -XlaOp Sqrt(XlaOp operand); - -// Computes the reciprocal of the square root of 'operand'. -XlaOp Rsqrt(XlaOp operand); +// Determines whether operand is +/-inf or nan. +// +// Raises an error if called on integral or complex values. +XlaOp IsPosInf(XlaOp operand); +XlaOp IsNegInf(XlaOp operand); +XlaOp IsInf(XlaOp operand); +XlaOp IsNan(XlaOp operand); + +// Determines whether operand is equal to -0. +// +// Raises an error for integral or complex values. +XlaOp IsNegZero(XlaOp operand); + +// Returns the next number after 'from' in the direction of 'to' the same way +// std::nextafter(from, to) would. +XlaOp NextAfter(XlaOp from, XlaOp to); // Computes the square of 'operand'. 
XlaOp Square(XlaOp operand); @@ -32,7 +43,7 @@ XlaOp Square(XlaOp operand); // Computes the reciprocal of 'operand'. XlaOp Reciprocal(XlaOp operand); -// Evaluates a polynomial given coefficients and `x`. +// Evaluates a polynomial given coefficients and 'x'. // N.B. Coefficients should be supplied in decreasing order. XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); @@ -86,7 +97,7 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); -// Applies a complex conjugation operation if `a` is complex and `conjugate` +// Applies a complex conjugation operation if 'a' is complex and 'conjugate' // is true, otherwise returns its argument. xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); diff --git a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f423d54dbb7ff911398b0137b482ee47f46c5c1 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc @@ -0,0 +1,188 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using Eigen::half; + +struct Testcase { + Testcase(string name, const std::function& op, + float (*host_op)(float)) + : name(name), op(op), host_op(host_op) {} + + Testcase& set_tolerance(float abs_err, float rel_err) { + error.abs = abs_err; + error.rel = rel_err; + return *this; + } + + Testcase& set_relaxed_nans() { + error.relaxed_nans = true; + return *this; + } + + Testcase& set_fewer_infs_ok() { + error.fewer_infs_ok = true; + return *this; + } + + Testcase& set_skip_pos_inf() { + skip_pos_inf = true; + return *this; + } + + Testcase& set_skip_neg_inf() { + skip_neg_inf = true; + return *this; + } + + Testcase& set_skip_infs() { + skip_pos_inf = true; + skip_neg_inf = true; + return *this; + } + + Testcase& set_skip_neg_zero() { + skip_neg_zero = true; + return *this; + } + + string name; + std::function op; + float (*host_op)(float); + + ErrorSpec error{0.01, 0.01}; + + // If true, don't test +/-infinity or negative 0. + bool skip_pos_inf = false; + bool skip_neg_inf = false; + bool skip_neg_zero = false; +}; + +void PrintTo(const Testcase& tc, std::ostream* os) { *os << tc.name; } + +class MathExhaustiveTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + public: + MathExhaustiveTest() { + // Disable fast-math, otherwise we get the wrong results for e.g. + // sqrt(-inf). + SetFastMathDisabled(true); + } +}; + +// Checks a function's behavior on all fp16 values. +// +// TODO(jlebar): asin and lgamma tests fail on interpreter. 
+XLA_TEST_P(MathExhaustiveTest, DISABLED_ON_INTERPRETER(F16)) { + const Testcase& tc = GetParam(); + XlaBuilder b(TestName()); + + std::vector input; + for (uint32 i = 0; i < 1 << 16; ++i) { + half h; + h.x = i; + + // If we're not using infinity as an input, use 0 as a placeholder rather + // than simply skipping this element. We do this because when the test + // framework reports an incorrect answer, it tells us which index failed. + // So long as our inputs are a simple list of all possible float16s, we can + // convert an index to a half with e.g. the following Python: + // + // np.frombuffer(array('H', [12345]), dtype=np.float16)[0] + // + // but as soon as our list of inputs has any gaps, this doesn't work. + if (std::isinf(static_cast(h)) && + ((tc.skip_pos_inf && h > half{0}) || + (tc.skip_neg_inf && h < half{0}))) { + h = half{0}; + } + + if (h == half{0} && tc.skip_neg_zero && + std::signbit(static_cast(h))) { + h = half{0}; + } + + input.push_back(h); + } + + std::vector expected_result; + for (const auto& h : input) { + expected_result.push_back( + static_cast(tc.host_op(static_cast(h)))); + } + + XlaOp param = AddParam(LiteralUtil::CreateR1(input), &b); + tc.op(param); + ComputeAndCompareR1(&b, expected_result, {}, tc.error); +} + +// TODO(b/123355973): The following tests from math.cc are missing. +// +// - Many failures. +// +// Testcase{"acosh", Acosh, std::acosh}.set_relaxed_nans(), +// Testcase{"asinh", Asinh, std::asinh}, +// Testcase{"sinh", Sinh, std::sinh}, +// Testcase{"cosh", Cosh, std::cosh}.set_fewer_infs_ok(), +// Testcase{"round_to_even", RoundToEven, +// [](float x) { return std::nearbyint(x / 2) * 2; }}, +// +// - No equivalent std function to compare with. +// +// Testcase{"erfinv", ErfInv, std::erfinv}, +// Testcase{"digamma", Digamma, std::digamma}, +// +// - Needs a special test (function takes two args, and simply computing in f32 +// and downcasting to f16 doesn't give the correct answer). 
+// +// Testcase{"nextafter", NextAfter, std::nextafter}, +// +// TODO(b/123355973): Test math functions not from math.cc (e.g. log). +// TODO(b/123355973): Test bf16 and f32. +// TODO(b/123355973): Get rid of skip_infs / skip_neg_zero below if possible. +// TODO(b/123355973): Reduce lgamma error if possible; it is very high. +INSTANTIATE_TEST_CASE_P( + MathExhaustiveTest_Instantiation, MathExhaustiveTest, + ::testing::ValuesIn(std::vector{ + Testcase{"sqrt", Sqrt, std::sqrt}.set_skip_neg_inf(), + Testcase{"rsqrt", Rsqrt, [](float x) { return 1 / std::sqrt(x); }} + .set_tolerance(0.05, 0.05) + .set_skip_infs() + .set_skip_neg_zero(), + Testcase{"square", Square, [](float x) { return x * x; }}, + Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }}, + Testcase{"erf", Erf, std::erf}.set_tolerance(0.001, 0.0001), + Testcase{"erfc", Erfc, std::erfc}.set_tolerance(0.001, 0.0001), + Testcase{"lgamma", Lgamma, std::lgamma} + .set_tolerance(0.1, 0.15) + .set_fewer_infs_ok(), + Testcase{"asin", Asin, std::asin}.set_skip_infs(), + Testcase{"acos", Acos, std::acos}.set_skip_infs(), + Testcase{"atan", Atan, std::atan}, + Testcase{"tan", Tan, std::tan}.set_tolerance(0.05, 0.05), + })); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index ae2ea225d1aadd7b3a794eabeca866c498f34760..bdfb0575f573716b54cf9116d155d8a3a55056e8 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -30,6 +31,138 @@ class MathTest : public ClientLibraryTestBase { ErrorSpec error_spec_{0.0001}; }; +// Write TYPED_TESTs within the class definition so that we don't have to litter +// "this->" everywhere. +template +class MathTypedTest : public MathTest { + public: + void TestLogEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + Log(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}}), &b)); + ComputeAndCompareR1(&b, + {-std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}, + {}, error_spec_); + } + + void TestLog1pEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + Log1p(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}, T{-1.0}}), &b)); + ComputeAndCompareR1( + &b, {T{0.0}, T{-0.0}, -std::numeric_limits::infinity()}, {}, + error_spec_); + } + + void TestIsInfOrNan() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + auto x = + ConstantR1(&b, { + T{0}, + T{100}, + T{-1000}, + T{std::numeric_limits::max()}, + T{std::numeric_limits::lowest()}, + T{std::numeric_limits::infinity()}, + T{-std::numeric_limits::infinity()}, + T{std::numeric_limits::quiet_NaN()}, + T{std::numeric_limits::signaling_NaN()}, + }); + Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)}); + + auto expected = LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1( + {true, true, true, true, true, false, false, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, true, true, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, true, false, false, false}), + LiteralUtil::CreateR1( + 
{false, false, false, false, false, false, true, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, false, false, true, true})); + ComputeAndCompareLiteral(&b, expected, {}); + } + + void TestIsNegZero() { + SetFastMathDisabled(true); + XlaBuilder b(TestName()); + T inf(std::numeric_limits::infinity()); + T nan(std::numeric_limits::quiet_NaN()); + IsNegZero(AddParam( + LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}), + &b)); + + ComputeAndCompareLiteral( + &b, + LiteralUtil::CreateR1( + {true, false, false, false, false, false, false}), + {}, error_spec_); + } +}; + +// TODO(b/123355973): Add bfloat16 to TestTypes once it's working. +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; +#endif + +TYPED_TEST_CASE(MathTypedTest, TestTypes); + +XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); } +XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); } + +// Check that certain ops only support real, floating-point inputs. +// +// TODO(jlebar): Expand this test to cover more ops. 
+XLA_TEST_F(MathTest, RealFpOnlyOps) { + for (int64 i = PrimitiveType_MIN; i <= PrimitiveType_MAX; ++i) { + auto ty = static_cast(i); + SCOPED_TRACE(PrimitiveType_Name(ty)); + Shape shape; + if (primitive_util::IsArrayType(ty)) { + shape = ShapeUtil::MakeShape(ty, {42}); + } else if (ty == PrimitiveType::TUPLE) { + shape = ShapeUtil::MakeTupleShape({}); + } else if (ty == PrimitiveType::OPAQUE) { + shape = ShapeUtil::MakeOpaqueShape(); + } else if (ty == PrimitiveType::TOKEN) { + shape = ShapeUtil::MakeTokenShape(); + } else { + continue; + } + + for (const auto& test : + std::vector, string>>({ + {IsFinite, "is_finite"}, + {IsInf, "is_inf"}, + {IsPosInf, "is_pos_inf"}, + {IsNegInf, "is_neg_inf"}, + {IsNan, "is_nan"}, + {Erf, "erf"}, + {Erfc, "erfc"}, + {Lgamma, "lgamma"}, + {Digamma, "digamma"}, + {RoundToEven, "round_to_even"}, + })) { + SCOPED_TRACE(test.second); + XlaBuilder b(TestName()); + XlaOp p = Parameter(&b, 0, shape, "p0"); + test.first(p); + + EXPECT_EQ(b.first_error().ok(), primitive_util::IsFloatingPointType(ty)); + } + } +} + XLA_TEST_F(MathTest, SqrtF32) { XlaBuilder builder(TestName()); Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); @@ -106,6 +239,27 @@ XLA_TEST_F(MathTest, Lgamma) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, LgammaF16) { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + + // These seemingly arbitrary inputs came from debugging the lgamma + // implementation against a test which tried all possible f16 values. 
+ auto x = ConstantR1(&b, { + half(-7360.0), + half(-4066.0), + half(-5.9605e-08), + }); + Lgamma(x); + std::vector expected = { + std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + half(16.64), + }; + ComputeAndCompareR1(&b, expected, {}, ErrorSpec{0.1}); +} + XLA_TEST_F(MathTest, Digamma) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125, @@ -148,5 +302,40 @@ XLA_TEST_F(MathTest, RoundToEven) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, ErfRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Erf(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, ErfcRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Erfc(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, LgammaRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Lgamma(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, DigammaRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Digamma(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, RoundToEvenRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + RoundToEven(x); + EXPECT_FALSE(b.Build().status().ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index ffd744d190885b8e3f4149a48a706498b3787618..a055a8e625c680cf5232896c95cd35b78cb172bc 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -15,40 +15,52 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n) { - auto a = Iota(builder, type, m); - auto b = Iota(builder, type, n); + auto a = Iota(builder, U32, m); + auto b = Iota(builder, U32, n); auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); return ConvertElementType(indicator, type); } -XlaOp GetMatrixDiagonal(XlaOp x) { +XlaOp GetMatrixDiagonal(XlaOp x, int k) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); + + auto offset = ConstantR0WithType(builder, S32, k); + absl::Span major_dims = AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + offset; auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); auto mask = 
Broadcast(indicator, major_dims); @@ -58,111 +70,269 @@ XlaOp GetMatrixDiagonal(XlaOp x) { primitive_util::IsIntegralType(shape.element_type()) ? CreateScalarOrComputation(shape.element_type(), builder) : CreateScalarAddComputation(shape.element_type(), builder); - - return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), - reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + // k == 0, we can save one slice op. + if (k == 0) { + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + } else if (k > 0) { + auto result = Reduce(Select(mask, x, Zeros(builder, shape)), + ScalarLike(x, 0), reducer, {n_dims - 2}); + return SliceInMinorDims(result, {std::min(k, n)}, + {std::min(m + k, n)}); + } else { + auto result = Reduce(Select(mask, x, Zeros(builder, shape)), + ScalarLike(x, 0), reducer, {n_dims - 1}); + return SliceInMinorDims(result, {std::min(-k, m)}, + {std::min(m, n - k)}); + } }); } -XlaOp Triangle(XlaOp x, bool lower) { +XlaOp TriangleMask(XlaOp x, int diagonal) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); absl::Span major_dims = AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + ConstantR0(builder, diagonal); XlaOp indicator; - if (lower) { - indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } else { - indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } - auto mask = Broadcast(indicator, major_dims); - - return Select(mask, x, Zeros(builder, shape)); + indicator = Ge(b, Broadcast(a, 
{m}), /*broadcast_dimensions=*/{0}); + return Broadcast(indicator, major_dims); }); } +XlaOp Triangle(XlaOp x, bool lower) { + return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x)) + : Select(TriangleMask(x, -1), ZerosLike(x), x); +} + XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } -XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { +Status ValidateEinsumNumericDimensions(absl::Span x_config, + absl::Span y_config, + absl::Span output_config) { + for (auto dim : output_config) { + if (absl::c_linear_search(x_config, dim) || + absl::c_linear_search(y_config, dim)) { + if (absl::c_count(output_config, dim) > 1) { + return InvalidArgument("Einsum has repeated output dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has output dimension without corresponding input dimension."); + } + for (auto dim : x_config) { + if (absl::c_linear_search(y_config, dim) || + absl::c_linear_search(output_config, dim)) { + if (absl::c_count(x_config, dim) > 1) { + return InvalidArgument("Einsum has repeated lhs dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has lhs dimension without corresponding rhs or output " + "dimension."); + } + for (auto dim : y_config) { + if (absl::c_linear_search(x_config, dim) || + absl::c_linear_search(output_config, dim)) { + if (absl::c_count(y_config, dim) > 1) { + return InvalidArgument("Einsum has repeated rhs dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has rhs dimension without corresponding lhs or output " + "dimension."); + } + return Status::OK(); +} + +xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, + absl::Span output_config, + xla::PrecisionConfig::Precision precision) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); - 
TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + TF_RETURN_IF_ERROR( + ValidateEinsumNumericDimensions(x_config, y_config, output_config)); + const int64 x_rank = x_config.size(); + const int64 y_rank = y_config.size(); + const int64 output_rank = output_config.size(); + absl::flat_hash_set x_map; + absl::flat_hash_set y_map; + absl::flat_hash_set output_map; + + auto find = [&](const absl::flat_hash_set& map, int64 d) { + return map.count(d) != 0; + }; - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) { - return InvalidArgument( - "Arguments to BatchDot have different ranks: %s vs. %s", - ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); + auto insert = [&](absl::flat_hash_set& map, char d) { + CHECK(!find(map, d)); + map.insert(d); + }; + + for (auto d : x_config) { + insert(x_map, d); } - const int ndims = ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return InvalidArgument( - "Arguments to BatchDot must have rank >= 2: got %d", ndims); + + for (auto d : y_config) { + insert(y_map, d); } - // The batch dimensions must be equal and the matrix dimensions must be - // valid. 
- std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return InvalidArgument( - "Dimension %d of inputs to BatchDot must be equal: shapes %s vs %s", - i, ShapeUtil::HumanString(x_shape), - ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); + for (auto d : output_config) { + insert(output_map, d); } - int x_inner_dim = ndims - 1; - int y_inner_dim = ndims - 2; - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return InvalidArgument( - "Dimensions %d and %d of arguments to BatchDot must be equal: " - "shapes %s vs %s", - x_inner_dim, y_inner_dim, ShapeUtil::HumanString(x_shape), - ShapeUtil::HumanString(y_shape)); + DotDimensionNumbers dnums; + std::vector lhs_outer_dims; + auto is_batch_dim = [&](int64 d) { + return find(x_map, d) && find(y_map, d) && find(output_map, d); + }; + auto is_contracting = [&](int64 d) { + return find(x_map, d) && find(y_map, d); + }; + auto rhs_dimension_number = [&](int64 d) { + return absl::c_find(y_config, d) - y_config.begin(); + }; + for (int64 i = 0; i < x_rank; ++i) { + auto dim_name = x_config[i]; + if (is_batch_dim(dim_name)) { + dnums.add_lhs_batch_dimensions(i); + dnums.add_rhs_batch_dimensions(rhs_dimension_number(dim_name)); + } else if (is_contracting(dim_name)) { + dnums.add_lhs_contracting_dimensions(i); + dnums.add_rhs_contracting_dimensions(rhs_dimension_number(dim_name)); + } else { + lhs_outer_dims.push_back(i); + } } - // Check for zero lhs/rhs dim size. 
- if (ShapeUtil::IsZeroElementArray(x_shape) || - ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + std::vector rhs_outer_dims; + for (int64 i = 0; i < y_rank; ++i) { + auto dim_name = y_config[i]; + if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) { + rhs_outer_dims.push_back(i); } - int x_outer_dim = ndims - 2; - int y_outer_dim = ndims - 1; - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return Broadcast( - ConstantLiteral(builder, LiteralUtil::Zero(x_shape.element_type())), - dimensions); + } + + auto output_dimension_number = [&](char d) { + return absl::c_find(output_config, d) - output_config.begin(); + }; + + std::vector output_dims; + output_dims.reserve(output_rank); + for (auto d : dnums.lhs_batch_dimensions()) { + output_dims.push_back(output_dimension_number(x_config[d])); + } + for (auto d : lhs_outer_dims) { + output_dims.push_back(output_dimension_number(x_config[d])); + } + for (auto d : rhs_outer_dims) { + output_dims.push_back(output_dimension_number(y_config[d])); + } + + std::vector transpose_dims(output_rank); + for (int64 i = 0; i < output_rank; ++i) { + transpose_dims[output_dims[i]] = i; } PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); + return Transpose(DotGeneral(x, y, dnums, &precision_proto), transpose_dims); + }); +} + +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - 
dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + const int ndims = x_shape.rank(); + batch_dimension_numbers.reserve(ndims - 2); + for (int i = 0; i < ndims - 2; ++i) { + batch_dimension_numbers.push_back(i); + } + std::vector x_config = batch_dimension_numbers; + x_config.push_back(ndims - 2); + x_config.push_back(ndims); + std::vector y_config = batch_dimension_numbers; + y_config.push_back(ndims); + y_config.push_back(ndims - 1); + std::vector output_config = batch_dimension_numbers; + output_config.push_back(ndims - 2); + output_config.push_back(ndims - 1); + return Einsum(x, x_config, y, y_config, output_config, precision); + }); +} + +StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config) { + std::array, 3> einsum_config_numeric; + std::vector main_split = + absl::StrSplit(einsum_config, ','); + + if (main_split.size() != 2) { + return InvalidArgument("Expected one \",\" in einsum_config."); + } + + auto maybe_invalid_character = [](char d) { + if (absl::ascii_isalpha(d)) { + return Status::OK(); } + if (d == '.') { + return InvalidArgument("Unsupported \"...\" or \".\" in einsum config."); + } + return InvalidArgument("Unexpected character in einsum config."); + }; + + auto& x_config = einsum_config_numeric[0]; + x_config.reserve(main_split[0].size()); + for (auto d : main_split[0]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + x_config.push_back(static_cast(d)); + } + std::vector y_output_split = + absl::StrSplit(main_split[1], "->"); + if (y_output_split.size() != 2) { + return InvalidArgument("Expected one \"->\" in einsum_config."); + } + auto& y_config = einsum_config_numeric[1]; + 
y_config.reserve(y_output_split[0].size()); + for (auto d : y_output_split[0]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + y_config.push_back(static_cast(d)); + } + auto& output_config = einsum_config_numeric[2]; + output_config.reserve(y_output_split[1].size()); + for (auto d : y_output_split[1]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + output_config.push_back(static_cast(d)); + } + return einsum_config_numeric; +} - return DotGeneral(x, y, dot_dnums, &precision_proto); +XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto einsum_config_numeric, + ParseEinsumString(einsum_config)); + return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1], + einsum_config_numeric[2], precision); }); } @@ -170,7 +340,7 @@ XlaOp TransposeInMinorDims(XlaOp x) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); std::vector permutation(n_dims); std::iota(permutation.begin(), permutation.end(), 0); @@ -182,4 +352,5 @@ XlaOp TransposeInMinorDims(XlaOp x) { XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { return transpose ? TransposeInMinorDims(x) : x; } + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 8856f99c7a0fee8f315aac11fab392cf5536f57b..60c41ec45a086726086dac7227fc432a9c62d0c8 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -16,7 +16,11 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -26,10 +30,19 @@ namespace xla { // else. XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); -// Get the diagonals of the last two dimensions. If 'x' has shape -// [..., M, N], then the output has shape [..., min(M, N)], containing the -// diagonal elements (i.e., with indices [..., i, i]). -XlaOp GetMatrixDiagonal(XlaOp x); +// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the +// main diagonal, and k<0 for diagonals below the main diagonal. +// +// If 'x' has shape [..., M, N] +// If k >= 0: then the output has shape [..., min(M, N - k)], containing the +// diagonal elements (i.e., with indices [..., i, i + k]). +// If k < 0: then the output has shape [..., min(M + k, N)], containing the +// diagonal elements (i.e., with indices [..., i - k, i]). +XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); + +// Returns a lower-triangular mask, i.e., true on and below the `diagonal`-th +// diagonal and false above that diagonal. +XlaOp TriangleMask(XlaOp x, int diagonal); // Get the upper or lower triangle part of the last two dimensions XlaOp Triangle(XlaOp x, bool lower); @@ -61,6 +74,40 @@ xla::XlaOp BatchDot( xla::XlaOp x, xla::XlaOp y, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); +// Parse an einsum string into dimension numbers: +// "ab,cb->ac" +// becomes: +// {{'a', 'b'},{'c', 'b'},{'a', 'c'}} (each label stored as its character code) +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. + +StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config); + +// Determine if each dimension label is in at least two inputs. 
+// +// NOTE: This function is meant for testing, there is no need to call it +// directly. +Status ValidateEinsumNumericDimensions(absl::Span x_config, + absl::Span y_config, + absl::Span output_config); + +// Supports two operand einsum notation like "ab,cb->ac". +xla::XlaOp Einsum( + xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Same as above but supporting numeric labels on dimensions. So "ab,cb->ac" +// becomes: +// x_config = {0, 1} +// y_config = {2, 1} +// output_config = {0, 2} +xla::XlaOp Einsum( + xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, absl::Span output_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + // Transposes a stack of matrices `x` by swapping the last two dimensions. xla::XlaOp TransposeInMinorDims(xla::XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc index 0593a7517ac125ca8dc5395cee76f6bc23232cd3..a93fc2ccb92912a10b9b6c2192b81cd73566f2a0 100644 --- a/tensorflow/compiler/xla/client/lib/matrix_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -15,13 +15,15 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { @@ -51,13 +53,24 @@ void MatrixTest::TestMatrixDiagonal() { XlaBuilder builder("GetMatrixDiagonal"); Array3D input(2, 3, 4); input.FillIota(0); - - XlaOp a; - auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); - GetMatrixDiagonal(a); - Array2D expected({{0, 5, 10}, {12, 17, 22}}); - - ComputeAndCompareR2(&builder, expected, {a_data.get()}); + std::map> k_and_expected = { + {0, {{0, 5, 10}, {12, 17, 22}}}, + {1, {{1, 6, 11}, {13, 18, 23}}}, + {2, {{2, 7}, {14, 19}}}, + {3, {{3}, {15}}}, + {4, {{}, {}}}, + {-1, {{4, 9}, {16, 21}}}, + {-2, {{8}, {20}}}, + {-3, {{}, {}}}, + {-4, {{}, {}}}, + }; + for (const auto& kv : k_and_expected) { + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + GetMatrixDiagonal(a, kv.first); + + ComputeAndCompareR2(&builder, kv.second, {a_data.get()}); + } } XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } @@ -101,5 +114,78 @@ XLA_TEST_F(MatrixTest, RowBatchDot) { ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); } + +XLA_TEST_F(MatrixTest, Einsum) { + XlaBuilder builder(TestName()); + + int n = 4; + + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of 
BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + Einsum(l_index, row, "abc,adc->abd"); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +XLA_TEST_F(MatrixTest, ParseEinsumString) { + auto to_vec = [](absl::string_view s) { + std::vector v; + v.reserve(s.size()); + for (auto c : s) { + v.push_back(int64{c}); + } + return v; + }; + + auto to_string = [&](absl::string_view x, absl::string_view y, + absl::string_view o) { + return absl::StrCat(x, ",", y, "->", o); + }; + + std::vector> good_test_cases = {{"ab", "bc", "ac"}, + {"Bab", "Bbc", "Bac"}, + {"ab", "cd", "dcba"}, + {"abc", "abd", "cbd"}}; + for (auto test_case : good_test_cases) { + auto parse_result_or_status = + ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2])); + EXPECT_TRUE(parse_result_or_status.status().ok()); + auto parse_result = parse_result_or_status.ValueOrDie(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(parse_result[i], to_vec(test_case[i])); + } + EXPECT_TRUE(ValidateEinsumNumericDimensions( + parse_result[0], parse_result[1], parse_result[2]) + .ok()); + } + + std::vector einsum_strings_that_fail_parsing = { + "", "a", "ab->ba", "ab,bc,cd->ad", "a...b,bc->a...c"}; + for (auto test_case : einsum_strings_that_fail_parsing) { + auto parse_result_or_status = ParseEinsumString(test_case); + EXPECT_FALSE(parse_result_or_status.status().ok()); + } + + std::vector einsum_strings_that_fail_numeric_validation = { + "a,b->c", "ab,bc->acd", "abz,bc->ac", "ab,bcz->ac"}; + for (auto test_case : einsum_strings_that_fail_numeric_validation) { + auto parse_result_or_status = ParseEinsumString(test_case); + EXPECT_TRUE(parse_result_or_status.status().ok()); + auto parse_result = parse_result_or_status.ValueOrDie(); + EXPECT_FALSE(ValidateEinsumNumericDimensions( + parse_result[0], 
parse_result[1], parse_result[2]) + .ok()); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc similarity index 62% rename from tensorflow/compiler/tf2xla/lib/qr.cc rename to tensorflow/compiler/xla/client/lib/qr.cc index d6007748609fdd161cb89692a167eb7ed12fe00c..640412ec8bcffd2565b11ba25b87f6bf6438d848 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -13,15 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -32,10 +31,18 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + // Computes a Householder reflection of the form: // H = I - tau v v.T. // such that @@ -65,52 +72,47 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. 
-xla::Status House(xla::XlaOp x, xla::XlaOp k, - absl::Span batch_dims, const int64 m, - xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { - xla::XlaBuilder* const builder = x.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - const xla::PrimitiveType type = x_shape.element_type(); +Status House(XlaOp x, XlaOp k, absl::Span batch_dims, + const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) { + XlaBuilder* const builder = x.builder(); + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + const PrimitiveType type = x_shape.element_type(); std::vector batch_dim_ids(batch_dims.size()); std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); const int64 minor_dim = batch_dims.size(); - xla::XlaOp zero = xla::ScalarLike(x, 0.0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); + XlaOp zero = ScalarLike(x, 0.0); + XlaOp one = ScalarLike(x, 1.0); // alpha = x[k] - xla::XlaOp alpha = - xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); + XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); // Compute x[k+1:] (padded with zeros in elements 0..k) - xla::XlaOp iota = xla::Iota(builder, xla::S32, m); - xla::XlaOp x_after_k = - xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type), - /*broadcast_dimensions=*/{minor_dim}); + XlaOp iota = Iota(builder, S32, m); + XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type), + /*broadcast_dimensions=*/{minor_dim}); // sigma = np.dot(x[k+1:], x[k+1:]) - auto sigma = - xla::Reduce(x_after_k * x_after_k, zero, - xla::CreateScalarAddComputation(type, builder), {minor_dim}); + auto sigma = Reduce(x_after_k * x_after_k, zero, + CreateScalarAddComputation(type, builder), {minor_dim}); // mu = np.sqrt(x[k]*x[k] + sigma) - auto mu = xla::Sqrt(xla::Square(alpha) + sigma); + auto mu = Sqrt(Square(alpha) + sigma); - auto sigma_is_zero = xla::Eq(sigma, zero); + auto sigma_is_zero = Eq(sigma, zero); - *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu); - *tau 
= xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims), - (*beta - alpha) / *beta); - auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims), - alpha - *beta); + *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu); + *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims), + (*beta - alpha) / *beta); + auto divisor = + Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta); - auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type), - std::vector(batch_dims.size(), 1)); + auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type), + std::vector(batch_dims.size(), 1)); // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - *v = e_k + - xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); + *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); return Status::OK(); } @@ -143,90 +145,86 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k, // return (q, vs, taus) struct QRBlockResult { // The factored R value - xla::XlaOp r; + XlaOp r; // Representation of the Householder matrices I - beta v v.T - xla::XlaOp taus; // Shape: [..., n] - xla::XlaOp vs; // Shape: [..., m, n] + XlaOp taus; // Shape: [..., n] + XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr QRBlock( - xla::XlaOp a, xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = a_shape.rank(); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = 
a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } std::vector batch_dim_indices(num_batch_dims); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - auto qr_body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto qr_body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto a = values[0]; auto vs = values[1]; auto taus = values[2]; // v, beta = house(a[:, j], j) auto x = DynamicSliceInMinorDims(a, {j}, {1}); - xla::XlaOp v, tau, beta; - TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j, + XlaOp v, tau, beta; + TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j, batch_dims, m, &v, &tau, &beta)); std::vector shape = batch_dims; shape.push_back(1); shape.push_back(m); - auto v_broadcast = xla::Reshape(v, shape); + auto v_broadcast = Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) auto vva = BatchDot(v_broadcast, a, precision); vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); - a = a - xla::Mul(tau, vva, - /*broadcast_dimensions=*/batch_dim_indices); + a = a - Mul(tau, vva, + /*broadcast_dimensions=*/batch_dim_indices); // It is more precise to populate column 'k' explicitly, rather than // computing it implicitly by applying the Householder transformation. 
// a[k,k] = beta // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) - auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1}); - auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type); - auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type), - std::vector(batch_dims.size(), 1)); - auto new_x = - xla::Mul(x, predecessor_mask, - /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + - xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); + auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); + auto predecessor_mask = ConvertElementType(Lt(iota, j), type); + auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), + std::vector(batch_dims.size(), 1)); + auto new_x = Mul(x, predecessor_mask, + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + + Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); // vs[:, j] = v vs = DynamicUpdateSliceInMinorDims( - vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); + vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); // taus[j] = tau taus = DynamicUpdateSliceInMinorDims( - taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); - return std::vector{a, vs, taus}; + taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); + return std::vector{a, vs, taus}; }; - auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); - auto taus = xla::Zeros( - builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); + auto vs = Zeros( + builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); + auto taus = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); - TF_ASSIGN_OR_RETURN(auto values, - XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, - {a, vs, taus}, "qr", builder)); + TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn, + {a, vs, taus}, "qr", builder)); 
QRBlockResult result; result.r = values[0]; @@ -250,24 +248,23 @@ xla::StatusOr QRBlock( // return W // There is no need to return Y since at termination of the loop it is equal to // vs. -xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfig::Precision precision) { +StatusOr ComputeWYRepresentation(PrimitiveType type, + absl::Span batch_dims, + XlaOp vs, XlaOp taus, int64 m, int64 n, + PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; - auto body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto w = values[0]; auto y = values[1]; const auto vs = values[2]; const auto taus = values[3]; // Want j values in range [1, ... n). 
- j = j + xla::ConstantR0(builder, 1); + j = j + ConstantR0(builder, 1); // vs has shape [..., m, 1] auto v = DynamicSliceInMinorDims(vs, {j}, {1}); // beta has shape [..., 1] @@ -278,31 +275,31 @@ xla::StatusOr ComputeWYRepresentation( // wyv has shape [..., m, 1] auto wyv = BatchDot(w, yv, precision); - auto z = xla::Mul( + auto z = Mul( -beta, v + wyv, /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = DynamicUpdateSliceInMinorDims(w, z, {j}); y = DynamicUpdateSliceInMinorDims(y, v, {j}); - return std::vector{w, y, vs, taus}; + return std::vector{w, y, vs, taus}; }; - xla::XlaBuilder* builder = vs.builder(); - auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); + XlaBuilder* builder = vs.builder(); + auto w = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); auto y = w; auto v = SliceInMinorDims(vs, {0}, {1}); auto beta = SliceInMinorDims(taus, {0}, {1}); y = UpdateSliceInMinorDims(y, v, {0}); - auto bv = xla::Mul( - -beta, v, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + auto bv = + Mul(-beta, v, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = UpdateSliceInMinorDims(w, bv, {0}); TF_ASSIGN_OR_RETURN( - auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, - "wy", builder)); + auto values, + ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder)); return values[0]; } @@ -323,34 +320,34 @@ xla::StatusOr ComputeWYRepresentation( // return (q, a) // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. 
-xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = a_shape.rank(); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 p = std::min(m, n); if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to QR must be >= 1; got ", block_size); + return InvalidArgument("block_size argument to QR must be >= 1; got %d", + block_size); } const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } - auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims); + auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); for (int64 i = 0; i < p; i += block_size) { int64 k = std::min(block_size, p - i); @@ -393,4 +390,4 @@ xla::StatusOr QRDecomposition( return result; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/qr.h 
b/tensorflow/compiler/xla/client/lib/qr.h similarity index 74% rename from tensorflow/compiler/tf2xla/lib/qr.h rename to tensorflow/compiler/xla/client/lib/qr.h index 24b537ac8b63b93e734c3d0e335ea455f7d51a54..827c8eeca05ef09a0d77363eb3c40961b95813d8 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/xla/client/lib/qr.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the QR decompositions of a batch of matrices. That is, // given a (batched) matrix a, computes an orthonormal matrix Q and an @@ -29,14 +29,14 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): handle the complex case. 
struct QRDecompositionResult { - xla::XlaOp q; - xla::XlaOp r; + XlaOp q; + XlaOp r; }; -xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size = 128, + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b27d364b62444d6d5fb1278b6e6461affc15b2e6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/qr.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using QrTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(QrTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + // Verifies that the decomposition composes back to the original matrix. + // + // This isn't a terribly demanding test, (e.g., we should verify that Q is + // orthonormal and R is upper-triangular) but it's awkward to write such tests + // without more linear algebra libraries. It's easier to test the numerics + // from Python, anyway, where we have access to numpy and scipy. 
+ xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR2(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(QrTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +} // namespace diff --git a/tensorflow/compiler/xla/client/lib/quantize.h b/tensorflow/compiler/xla/client/lib/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..26dbbd5b00bd1a29f4047c9a4294fcac7340cf6c --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize.h @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" + +namespace xla { + +constexpr int64 kBitsOfByte = 8; + +// Represents the range used for quantization +struct QuantizedRange { + QuantizedRange() = default; + QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} + + bool operator==(const QuantizedRange& rhs) const { + return this->min == rhs.min && this->max == rhs.max; + } + + bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } + + tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f); + tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f); +}; + +template +inline std::vector PackToUint32(absl::Span input) { + const int64 kElementsPerPack = sizeof(uint32) / sizeof(T); + const int64 input_size = input.size(); + const int64 output_size = CeilOfRatio(input_size, kElementsPerPack); + + std::vector output_vec; + constexpr int64 kShiftBits = sizeof(T) / sizeof(uint8) * kBitsOfByte; + + for (int64 i = 0; i < output_size; i++) { + uint32 result = 0; + for (int64 p = 0; p < kElementsPerPack; p++) { + int64 index = i * kElementsPerPack + p; + if (index < input_size) { + int64 total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); + result |= (input[index] << total_shift_bits); + } + } + output_vec.push_back(result); + } + + return output_vec; +} + +// Dequantize the quantized input of packed uint32 to bfloat16. +// Only uint8 or uint16 is supported for the original unpacked input. 
+// Returns a tensor of shape [d0,..., dn * unpack_size] if +// input shape is [d0, ..., dn], where unpack_size = sizeof(uint32) / sizeof(T). +// If transpose_output is true, will return a tensor of shape +// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when +// input's rank is higher than 1. The input needs to be transposed to use +// the transpose_output feature. +template +inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, + absl::string_view mode_string = "MIN_COMBINED", + bool transpose_output = false) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + std::numeric_limits::min() + 1) / + 2.0f; + const int64 unpack_size = sizeof(uint32) / sizeof(T); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); + + auto element_type = shape.element_type(); + if (element_type != U32) { + return InvalidArgument( + "Only U32 is supported for input type of xla::Dequantize Op."); + } + + // Broadcast the input to [unpack_size, d0, ..., dn] if input size is + // [d0, ..., dn]. + auto broadcast_input = Broadcast(input, {unpack_size}); + + XlaOp iota_r1 = Iota(builder, U32, unpack_size); + // Higher significant bytes need to shift more bytes than lower + // significant bytes. + XlaOp shift_bytes = + xla::ConstantR0(builder, unpack_size - 1) - iota_r1; + + const int bytes_of_type = sizeof(T) / sizeof(uint8); + std::vector shift_vec(unpack_size, kBitsOfByte * bytes_of_type); + XlaOp shift_bits = + shift_bytes * xla::ConstantR1(builder, shift_vec); + + // Make bit_mask for different data type T.
+ uint32 bit_mask = 0x00000000; + for (int i = 0; i < bytes_of_type; i++) { + bit_mask <<= kBitsOfByte; + bit_mask |= 0x000000ff; + } + + std::vector shift_transpose_dimensions(shape.dimensions_size()); + std::iota(shift_transpose_dimensions.begin(), + shift_transpose_dimensions.end(), 0); + shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, + shape.dimensions_size()); + + // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. + XlaOp shifted_input = ShiftRightLogical( + broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), + shift_transpose_dimensions)); + XlaOp unpack_input = + And(shifted_input, xla::ConstantR0(builder, bit_mask)); + + XlaOp result; + + if (mode_string == "MIN_COMBINED") { + const tensorflow::bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + // result = bfloat16(input + half_range) * scale_factor + range.min + XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); + XlaOp half_range_bf16 = xla::ConstantR0( + builder, static_cast(half_range)); + XlaOp sum = unpack_input_bf16 + half_range_bf16; + + result = + sum * xla::ConstantR0(builder, scale_factor) + + xla::ConstantR0(builder, range.min); + } else { + // TODO(wangtao): support other modes. + return InvalidArgument( + "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); + } + + std::vector transpose_dimensions(shape.dimensions_size()); + std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); + std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); + transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); + + // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. + XlaOp transposed_result = Transpose(result, transpose_dimensions); + + // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. 
+ XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); + + // Return the transpose result if transpose_output is true. + if (transpose_output) { + return reshaped_result; + } + + // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. + std::vector result_dimensions(shape.dimensions_size()); + std::iota(result_dimensions.begin(), result_dimensions.end(), 0); + std::reverse(result_dimensions.begin(), result_dimensions.end()); + + return Transpose(reshaped_result, result_dimensions); + }); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ diff --git a/tensorflow/compiler/xla/client/lib/quantize_test.cc b/tensorflow/compiler/xla/client/lib/quantize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..be3603d9e11670913c21a834d2216a999306d582 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize_test.cc @@ -0,0 +1,337 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/quantize.h" + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace { + +using bfloat16 = tensorflow::bfloat16; + +template +std::vector GenerateInput() { + std::vector input; + + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + input.push_back(static_cast(i)); + } + + return input; +} + +template +Array2D GenerateLargeSizeInput(int num_columns, int num_rows) { + Array2D input(num_columns, num_rows); + + input.FillRandom(6, 128); + + return input; +} + +template +Array2D PackLargeInput(Array2D &input) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack); + + Array2D pack_input(input.height(), padded_output_width); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + input_row.push_back(input({h, w})); + } + + auto pack_input_vec = PackToUint32(input_row); + + for (int w = 0; w < padded_output_width; w++) { + pack_input(h, w) = pack_input_vec[w]; + } + } + + return pack_input; +} + +template +Array2D GenerateLargeSizeMinCombinedOutput( + Array2D &input, const QuantizedRange &range, + bool transpose_output = false) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack) * size_per_pack; + + int64 output_height; + int64 output_width; + + if (transpose_output) { + output_height = padded_output_width; + output_width = input.height(); + } else 
{ + output_height = input.height(); + output_width = padded_output_width; + } + + Array2D output(output_height, output_width, bfloat16(0.0)); + + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + bfloat16 result = + static_cast(input(h, w) + half_range) * scale_factor + + range.min; + if (transpose_output) { + output(w, h) = result; + } else { + output(h, w) = result; + } + } + } + + return output; +} + +template +std::vector GenerateMinCombinedOutput(const QuantizedRange &range) { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + std::vector output; + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + bfloat16 result = + static_cast(i + half_range) * scale_factor + range.min; + output.push_back(result); + } + + const int64 pack_size = sizeof(uint32) / sizeof(NativeT); + const int64 output_size = output.size(); + + int64 num_tailing_zeros = + CeilOfRatio(output_size, pack_size) * pack_size - output_size; + + output.insert(output.end(), num_tailing_zeros, bfloat16(0.0)); + return output; +} + +// TODO(wangtao): add a test to make sure this op is the inverse of the existing +// TF quantize op defined in: third_party/tensorflow/core/kernels/quantize_op.cc + +using DequantizeTest = ClientLibraryTestBase; + +TEST(PackTest, PackUint8ToUint32) { + std::vector input = {0xAB, 0x0B, 0x00, 0xF0, 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, 
::testing::ElementsAre(0xAB0B00F0, 0x01000000)); +} + +TEST(PackTest, PackInt8ToUint32) { + std::vector input = {static_cast(0x81), 0x0B, 0x00, 0x20, + 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x810B0020, 0x01000000)); +} + +TEST(PackTest, PackUint8ToUint32PerfectSize) { + std::vector input = {3, 2, 1, 0}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x03020100)); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint16R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 127.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TransposeOutput) { 
+ XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZero) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0), + bfloat16(16.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0), + bfloat16(17.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0), + bfloat16(18.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0), + bfloat16(19.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZeroTransposeOutput) { + XlaBuilder 
builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + {bfloat16(16.0), bfloat16(17.0), bfloat16(18.0), bfloat16(19.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTest) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = + GenerateLargeSizeMinCombinedOutput(input, range); + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTestTransposeOutput) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = 
GenerateLargeSizeMinCombinedOutput( + input, range, /*transpose_output=*/true); + ComputeAndCompareR2(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc new file mode 100644 index 0000000000000000000000000000000000000000..546127e4627f1717913d1039be13fd0c655be1a3 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -0,0 +1,471 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// Jacobi rotation (also known as Givens rotation): +// G = [[ c, s], +// [-s, c]] +// matmul(G_T, G) = I +struct SymmetricSchurDecomposition { + XlaOp c; // cosine. + XlaOp s; // sine. +}; + +// JacobiUpdate holds the intermediate orthogonal matrix and the Jacobi-rotated +// matrix. After each Jacobi iteration, the off-diagonal norm of the rotated +// matrix is reduced. +struct JacobiUpdate { + XlaOp v; + XlaOp w; +}; + +struct FrobeniusNorms { + XlaOp off_diagonal_norm; + XlaOp total_norm; +}; + +// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n, +// it computes a rotation matrix G = [[c, s], [-s, c]], such that +// G_T * A[[p, q], [p, q]] * G +// is diagonalized.
+// +// def sym_schur2x2(A, p, q): +// if np.abs(A[p, q]) > 1e-6: +// tau = (A[q, q] - A[p, p]) / (2 * A[p, q]) +// if tau >= 0: +// t = 1.0 / (tau + np.sqrt(1 + tau ** 2)) +// else: +// t = -1.0 / (-tau + np.sqrt(1 + tau ** 2)) +// c = 1.0 / np.sqrt(1.0 + t ** 2) +// s = t * c +// else: +// c = 1.0 +// s = 0.0 +// return c, s +StatusOr SymmetricShurDecomposition2x2(XlaOp a, + XlaOp p, + XlaOp q, + XlaOp tol) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + + auto zero = ScalarLike(a, 0.0); + auto one = ScalarLike(a, 1.0); + auto two = ScalarLike(a, 2.0); + + auto pqs = DynamicSliceInMinorDims(a, {p, q}, {1, 1}); + + auto ps = DynamicSliceInMinorDims(a, {p, p}, {1, 1}); + auto qs = DynamicSliceInMinorDims(a, {q, q}, {1, 1}); + + auto tau = (qs - ps) / (pqs * two); + auto t_pos = one / (tau + Sqrt(one + Square(tau))); + auto t_neg = -one / (-tau + Sqrt(one + Square(tau))); + auto t = Select(Ge(tau, zero), t_pos, t_neg); + + auto c_temp = Rsqrt(one + Square(t)); + auto s_temp = t * c_temp; + + auto c = Select(Ge(Abs(pqs), tol), c_temp, ZerosLike(c_temp) + one); + auto s = Select(Ge(Abs(pqs), tol), s_temp, ZerosLike(s_temp)); + // Renormalize c and s to compensate for low precision arithmetic, this step + // is redundant if high precision float is used, like float64. 
+ auto rnorm = Rsqrt(Square(c) + Square(s)); + + SymmetricSchurDecomposition schur; + + schur.c = c * rnorm; + schur.s = s * rnorm; + + return schur; +} + +StatusOr Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q, + XlaOp tol, int64 n) { + XlaBuilder* builder = jacobi_update.w.builder(); + TF_ASSIGN_OR_RETURN( + SymmetricSchurDecomposition schur, + SymmetricShurDecomposition2x2(jacobi_update.w, p, q, tol)); + + TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(jacobi_update.w)); + const std::vector batch_dims(w_shape.dimensions().begin(), + w_shape.dimensions().end() - 2); + const int64 num_dims = w_shape.rank(); + + auto zero = ScalarLike(p, 0); + + XlaOp c = schur.c; + XlaOp s = schur.s; + + auto slice_p = DynamicSliceInMinorDims(jacobi_update.w, {p, zero}, {1, n}); + auto slice_q = DynamicSliceInMinorDims(jacobi_update.w, {q, zero}, {1, n}); + + auto slice_p_new = c * slice_p - s * slice_q; + auto slice_q_new = s * slice_p + c * slice_q; + + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {p, zero}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {q, zero}); + + slice_p = DynamicSliceInMinorDims(jacobi_update.w, {zero, p}, {n, 1}); + slice_q = DynamicSliceInMinorDims(jacobi_update.w, {zero, q}, {n, 1}); + + slice_p_new = c * slice_p - s * slice_q; + slice_q_new = s * slice_p + c * slice_q; + + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {zero, p}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {zero, q}); + + // Zero out a_{pq} explicitly. 
+ std::vector pq_dims(batch_dims.begin(), batch_dims.end()); + pq_dims.push_back(1); + pq_dims.push_back(1); + auto pq_zero = ScalarLike(jacobi_update.w, 0.0); + auto pq_zeros = Broadcast(pq_zero, pq_dims); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {p, q}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {q, p}); + + slice_p = DynamicSliceInMinorDims(jacobi_update.v, {zero, p}, {n, 1}); + slice_q = DynamicSliceInMinorDims(jacobi_update.v, {zero, q}, {n, 1}); + + std::vector broadcast_dims(batch_dims.size()); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims.push_back(num_dims - 1); + + // Renormalize the p-th and q-th columns. This step is redundant if high + // precision floats are used, like 64-bit float. But for 32-bit float, it + // becomes necessary. This step will not increase the overall complexity. + slice_p_new = c * slice_p - s * slice_q; + slice_p_new = Mul( + slice_p_new, + Rsqrt(Reduce(Square(slice_p_new), pq_zero, + CreateScalarAddComputation(w_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + slice_q_new = s * slice_p + c * slice_q; + slice_q_new = Mul( + slice_q_new, + Rsqrt(Reduce(Square(slice_q_new), pq_zero, + CreateScalarAddComputation(w_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + jacobi_update.v = + DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_p_new, {zero, p}); + jacobi_update.v = + DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q}); + + return jacobi_update; +} + +StatusOr ComputeFrobeniusNorms(XlaOp w) { + XlaBuilder* builder = w.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); + const int64 num_dims = shape.rank(); + auto frobenius_norm = + Sqrt(Reduce(Square(w), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2, num_dims - 1})); + auto diag = GetMatrixDiagonal(w); + auto diag_square = + 
Reduce(Square(diag), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2}); + + FrobeniusNorms frobenius_norms; + + frobenius_norms.off_diagonal_norm = + Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0))); + frobenius_norms.total_norm = frobenius_norm; + + return frobenius_norms; +} + +StatusOr> WhileLoopFn( + absl::Span initial_values, // + int matrix_dimension, // + int max_sweep_updates, // + PrimitiveType index_type, // + absl::string_view name, // + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + auto k = values[0]; + auto max_sweeps = ScalarLike(k, max_sweep_updates); + auto sweep_update_cond = Gt(max_sweeps, k); + + auto norms = ComputeFrobeniusNorms(values[2]).ValueOrDie(); + auto tol = norms.total_norm * values[3]; + auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), + xla::ConstantR0(cond_builder, false), + CreateScalarOrComputation(PRED, cond_builder)); + + return And(sweep_update_cond, tol_cond); + }; + + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + auto while_cond_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_cond_builder) -> StatusOr { + auto p = values_inner[0]; + return Lt(p, ScalarLike(p, matrix_dimension - 1)); + }; + + auto while_body_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_body_builder) -> StatusOr> { + auto while_cond_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_cond_builder) -> StatusOr { + auto q = values_innermost[1]; + return Lt(q, ScalarLike(q, matrix_dimension)); + }; + auto while_body_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_body_builder) + -> StatusOr> { + auto p = values_innermost[0]; + auto q = values_innermost[1]; + + JacobiUpdate jacobi_update; + jacobi_update.v = values_innermost[2]; + jacobi_update.w = values_innermost[3]; + + auto tol = 
values_innermost[4]; + + TF_ASSIGN_OR_RETURN(jacobi_update, + Update(jacobi_update, p, q, tol, matrix_dimension)); + + std::vector updated_values_innermost; + updated_values_innermost.reserve(values_innermost.size()); + + updated_values_innermost.push_back(p); + updated_values_innermost.push_back(q + ScalarLike(q, 1)); + updated_values_innermost.push_back(jacobi_update.v); + updated_values_innermost.push_back(jacobi_update.w); + updated_values_innermost.push_back(tol); + + return updated_values_innermost; + }; + + std::vector values_innermost(5); + auto p = values_inner[0]; + auto q = p + ScalarLike(p, 1); + values_innermost[0] = p; // index p. + values_innermost[1] = q; // index q. + values_innermost[2] = values_inner[1]; // v. + values_innermost[3] = values_inner[2]; // w. + values_innermost[4] = values_inner[3]; // tol. + TF_ASSIGN_OR_RETURN( + values_innermost, + WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost, + values_innermost, absl::StrCat(name, "-Innermost"), + inner_body_builder)); + + std::vector updated_values_inner; + updated_values_inner.reserve(values_inner.size()); + + updated_values_inner.push_back(p + ScalarLike(p, 1)); + updated_values_inner.push_back(values_innermost[2]); + updated_values_inner.push_back(values_innermost[3]); + updated_values_inner.push_back(values_innermost[4]); + return updated_values_inner; + }; + // Indexes. + XlaOp k = values[0]; + + std::vector values_inner(4); + values_inner[0] = ScalarLike(k, 0); // index p. + values_inner[1] = values[1]; // v. + values_inner[2] = values[2]; // w. + values_inner[3] = values[3]; // tol. 
+ TF_ASSIGN_OR_RETURN( + values_inner, + WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner, + absl::StrCat(name, "-Inner"), body_builder)); + + std::vector updated_values; + updated_values.reserve(values_inner.size()); + + updated_values.push_back(k + ScalarLike(k, 1)); + updated_values.push_back(values_inner[1]); + updated_values.push_back(values_inner[2]); + updated_values.push_back(values_inner[3]); + + return updated_values; + }; + std::vector values; + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + initial_values, name, builder)); + + return values; +} + +StatusOr SortByEigenvalues(SelfAdjointEigResult result) { + XlaBuilder* builder = result.v.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.v)); + const int64 num_dims = shape.rank(); + auto dimensions = shape.dimensions(); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + result.w = BroadcastInDim(result.w, dimensions, broadcast_dims); + + XlaOp sort_result = + Sort({result.w, result.v}, + CreateScalarLtComputation( + {shape.element_type(), shape.element_type()}, builder), + num_dims - 1); + result.w = GetMatrixDiagonal(GetTupleElement(sort_result, 0)); + result.v = GetTupleElement(sort_result, 1); + return result; +} + +} // namespace + +// This is the cyclic Jacobi iteration. Please note that the eigenvalues are +// possibly not ordered. 
+// +// def jacobi(A): +// n, _ = A.shape +// V = np.eye(n) +// frobenius_norm = np.linalg.norm(A) +// diag_norm = np.linalg.norm(np.diag(A)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm) +// while off_diag_norm > 1e-6 * frobenius_norm: +// for p in range(n - 1): +// for q in range(p + 1, n): +// c, s = sym_schur2x2(A, p, q) +// A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]), +// A[[p, q], :]) +// A[:, [p, q]] = np.matmul(A[:, [p, q]], +// np.array([[c, s], [-s, c]])) +// V[:, [p, q]] = np.matmul(V[:, [p, q]], +// np.array([[c, s], [-s, c]])) +// frobenius_norm = np.linalg.norm(A) +// diag_norm = np.linalg.norm(np.diag(A)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt( +// frobenius_norm + diag_norm) +// +// return A, V +// +// TODO(kuny): Implement parallel order Jacobi. +// +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, + float epsilon) { + XlaBuilder* builder = a.builder(); + auto return_error = [&](const Status& status) { + SelfAdjointEigResult result; + result.v = builder->ReportError(status); + result.w = builder->ReportError(status); + return result; + }; + auto shape_with_status = builder->GetShape(a); + if (!shape_with_status.status().ok()) { + return return_error(shape_with_status.status()); + } + Shape a_shape = shape_with_status.ValueOrDie(); + const int64 num_dims = a_shape.rank(); + if (num_dims < 2) { + return return_error(InvalidArgument( + "Arguments to Eigen decomposition must have rank >= 2: got shape %s.", + a_shape.ToString())); + } + PrimitiveType type = a_shape.element_type(); + if (!primitive_util::IsFloatingPointType(type)) { + return return_error(InvalidArgument( + "Type of the input matrix must be float: got %s.", a_shape.ToString())); + } + + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + + if (m != n) { + return return_error(InvalidArgument( + "Arguments 
to Eigen decomposition must be square matrices: got shape " + "(%d, %d).", + m, n)); + } + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + } + + auto tol = ScalarLike(a, epsilon); + + auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); + auto w_init = Triangle(a, lower); + w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init; + + auto output_with_status = WhileLoopFn( + { + Zero(builder, S32), // k + v_init, // v + w_init, // w + tol, // + }, // + n, // + max_iter, // + S32, // + "CyclicJacobi", // + builder); + if (!output_with_status.status().ok()) { + return return_error(output_with_status.status()); + } + + auto output = output_with_status.ValueOrDie(); + + SelfAdjointEigResult result; + result.v = output[1]; + result.w = GetMatrixDiagonal(output[2]); + + return SortByEigenvalues(result).ValueOrDie(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h new file mode 100644 index 0000000000000000000000000000000000000000..2a089891d6a2d80c0c265a3310539b4f1c5db4d5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The eigenvalue decomposition of a symmetric matrix, the original matrix is +// recovered by v * w * v_t. +struct SelfAdjointEigResult { + // The i-th column is the normalized eigenvector corresponding to the + // eigenvalue w[i]. Will return a matrix object if a is a matrix object. + XlaOp v; + // The eigenvalues in ascending order, each repeated according to its + // multiplicity. + XlaOp w; +}; + +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, + int64 max_iter = 100, float epsilon = 1e-6); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8875dff7bfdbd4e133297cef0a6686bfcd9bb6f --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -0,0 +1,313 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +class SelfAdjointEigTest : public ClientLibraryTestBase { + protected: + void SetUp() override { + ClientLibraryTestBase::SetUp(); + batch_3d_4x4_ = Array3D{ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 100, 6}, + {12, 48, 6, 62}, + }, + }; + matrix2d_8x8_ = Array2D{ + {14., 123., 49., 112., 115., 173., 182., 125.}, + {123., 14., 60., 118., 150., 130., 91., 72.}, + {49., 60., 138., 111., 106., 101., 115., 142.}, + {112., 118., 111., 142., 91., 130., 25., 61.}, + {115., 150., 106., 91., 116., 121., 128., 85.}, + {173., 130., 101., 130., 121., 70., 151., 132.}, + {182., 91., 115., 25., 128., 151., 66., 92.}, + {125., 72., 142., 61., 85., 132., 92., 156.}, + }; + low_rank_4x4_ = Array2D{ + // x = [[1, 2, 3, 4], [1, -1, 1, -1]] + // matmul(x.T, x) + {2, 1, 4, 3}, + {1, 5, 5, 9}, + {4, 5, 10, 11}, + {3, 9, 11, 17}, + }; + } + void TearDown() override { ClientLibraryTestBase::TearDown(); } + + Array3D GetUnitMatrix3D(const Array3D& matrix) { + Array3D 
result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); + for (int i = 0; i < matrix.n1(); ++i) { + for (int j = 0; j < matrix.n2(); ++j) { + result({i, j, j}) = 1.0; + } + } + return result; + } + + Array3D ExtractTriangularMatrix(const Array3D& matrix, + bool lower) { + Array3D result(matrix); + for (int i = 0; i < result.n1(); ++i) { + for (int j = 0; j < result.n2(); ++j) { + if (lower) { + for (int k = j + 1; k < result.n3(); ++k) { + result({i, j, k}) = 0.0; + } + } else { + for (int k = 0; k < j; ++k) { + result({i, j, k}) = 0.0; + } + } + } + } + return result; + } + + XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { + Shape shape = builder->GetShape(result.v).ValueOrDie(); + std::vector out_dims = shape.dimensions(); + std::vector broadcast_dims(shape.rank() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + + broadcast_dims[shape.rank() - 2] = shape.rank() - 1; + auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims)); + return BatchDot(vw, TransposeInMinorDims(result.v), + PrecisionConfig::HIGHEST); + } + + XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { + Shape shape = builder->GetShape(m1).ValueOrDie(); + int64 size = 1; + for (auto d : shape.dimensions()) { + size *= d; + } + return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0), + CreateScalarAddComputation(F32, builder)) / + ConstantR0WithType(builder, F32, size); + } + + Array2D GenerateRandomSymmetricMatrix(int size) { + Array2D result{size, size, 0.0}; + result.FillRandom(10 /* stddev */, 2 /* mean */); + for (int i = 0; i < size; ++i) { + for (int j = 0; j < i; ++j) { + result({j, i}) = result({i, j}); + } + } + return result; + } + + Array3D batch_3d_4x4_; + Array2D matrix2d_8x8_; + Array2D low_rank_4x4_; + Array2D wrong_type_4x4_; +}; + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", 
&builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter( + ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter( + ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a); + auto result = SelfAdjointEig(a, false); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, GetUnitMatrix3D(batch_3d_4x4_), + {a_data.get()}, ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR2Parameter(low_rank_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR2(&builder, low_rank_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { + XlaBuilder builder(TestName()); + + // This is computed by numpy.linalg.eigh with float32. 
+ std::vector expected{-182.69205, -116.86245, -105.74489, -9.545369, + 37.81711, 104.732285, 120.29153, 868.00385}; + + XlaOp a; + auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + Add(result.w, ZerosLike(result.w)); + + ComputeAndCompareR1(&builder, expected, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { + XlaBuilder builder(TestName()); + + float expected_vals = 1e-3; + + XlaOp a; + auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2 + GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8), + BatchDot(TransposeInMinorDims(result.v), result.v), + &builder); + + ComputeAndCompareR0(&builder, expected_vals, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR2Parameter(wrong_type_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + EXPECT_FALSE(result.v.valid()); + EXPECT_FALSE(result.w.valid()); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) { + XlaBuilder builder(TestName()); + int size = 8; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) { + XlaBuilder builder(TestName()); + int size = 16; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + 
ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) { + XlaBuilder builder(TestName()); + int size = 32; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) { + XlaBuilder builder(TestName()); + int size = 256; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) { + XlaBuilder builder(TestName()); + int size = 512; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index f8c7df3ff5189c817202eaf39adb572f7e232ec2..d7b33c5af25606c4e7e443027b913f7ca13a013c 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace xla { @@ -26,7 +27,7 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span start, TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_minor_dims <= n_dims); auto major_dims = AsInt64Slice(shape.dimensions()) .subspan( @@ -51,17 +52,17 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span start, XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + TF_RET_CHECK(start.size() == n_dims); + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return DynamicUpdateSlice(x, update, start_constant); + std::vector start_ops(start.size()); + for (int i = 0; i < start.size(); ++i) { + start_ops[i] = ConstantR0(builder, start_as_int32[i]); + } + return DynamicUpdateSlice(x, update, start_ops); }); } @@ -70,7 +71,7 @@ XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); const int64 n_minor_dims = start.size(); 
TF_RET_CHECK(n_minor_dims <= n_dims); std::vector padded_start(n_dims, 0); @@ -90,18 +91,17 @@ std::vector ConcatVectors(absl::Span xs, return output; } -XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { +StatusOr> PrependZerosInMajorDims( + XlaOp x, absl::Span starts) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - auto zero = Reshape(ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1}); - } - return ConcatInDim(builder, padded_starts, 0); - }); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + auto zero = ConstantR0(builder, 0); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = starts[i]; + } + return padded_starts; } } // namespace @@ -111,7 +111,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); int64 n_minor_dims = starts.size(); TF_RET_CHECK(n_minor_dims == sizes.size()); TF_RET_CHECK(n_minor_dims <= n_dims); @@ -119,7 +119,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, .subspan( /*pos=*/0, /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); + TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); auto padded_sizes = ConcatVectors(major_dims, sizes); return DynamicSlice(x, padded_starts, padded_sizes); }); @@ -127,8 +127,38 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaOp DynamicUpdateSliceInMinorDims(XlaOp 
x, XlaOp update, absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return DynamicUpdateSlice(x, update, padded_starts); + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); + return DynamicUpdateSlice(x, update, padded_starts); + }); +} + +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); + ShapeUtil::AppendMajorDimension(1, &index_shape); + std::vector to_concat; + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + to_concat.reserve(input_shape.rank()); + for (int64 i = 0; i < input_shape.rank(); ++i) { + if (i == dim) { + to_concat.push_back(Reshape(index, index_shape.dimensions())); + } else { + to_concat.push_back(Iota(builder, index_shape, i)); + } + } + XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank()); + std::vector slice_sizes(input_shape.rank(), 1); + GatherDimensionNumbers gather_dnums; + gather_dnums.set_index_vector_dim(input_shape.rank()); + for (int64 i = 0; i < input_shape.rank(); ++i) { + gather_dnums.add_collapsed_slice_dims(i); + gather_dnums.add_start_index_map(i); + } + return Gather(input, gather_indices, gather_dnums, slice_sizes); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890..69f98a6f43fa167adf6f77b28645a3460b292633 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -43,6 +43,20 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, absl::Span starts); +// Gathers values along an axis specified by dim. 
+// +// For a 3-D tensor the output is specified by: +// +// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 +// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 +// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 +// +// If `input` is an n-dimensional tensor with size +// [X0,X1,X2,...,XN] and dim = i, `index` must be an n-dimensional tensor with size +// [X0,X1,...,Y,Xi+1,...,XN] where Y >= 1, and `out` will have the same sizes as +// `index`. +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 8d362119e01006555db0f82d02626175936e1d05..db6ebb9df18372260a64a3e9fd17b0c30b35667d 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -102,5 +102,18 @@ XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } +XLA_TEST_F(SlicingTest, TorchGather) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR2Parameter({{1, 2}, {3, 4}}, 0, "input", &builder, &input); + auto index_data = + CreateR2Parameter({{0, 0}, {1, 0}}, 1, "index", &builder, &index); + TorchGather(input, index, 1); + + ComputeAndCompareR2(&builder, {{1, 1}, {4, 3}}, + {input_data.get(), index_data.get()}); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index e8553a08bb014e790822a14e128686b60b8d6b7c..ddc39f4d874cd3613a763b969091e7e65ff1c783 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -30,7 +31,13 @@ XlaOp TopK(XlaOp input, int64 k) { ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); auto input_dims = input_shape.dimensions(); - XlaOp sort_result = Sort(Neg(input), {iota_s32}); + // TODO(b/122298745): Get rid of Neg() and use CreateScalarGtComputation + // once the TPU backend supports the comparison computations. + XlaOp sort_result = + Sort({Neg(input), iota_s32}, + CreateScalarLtComputation({input_shape.element_type(), S32}, + iota_s32.builder()), + last_dim, /*is_stable=*/true); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index 27ff36c7491ab8397d46f3a49493ff2b904deb2d..0fbd138aca1e86f219d0459086fc09d20844f135 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -77,7 +77,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) { auto x = ConstantR1(&builder, inputs); xla::GetTupleElement(xla::TopK(x, kSize), 0); - std::sort(inputs.begin(), inputs.end(), std::greater()); + absl::c_sort(inputs, std::greater()); ComputeAndCompareR1(&builder, inputs, {}); } diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index a95bbf2c8c860914877d3195b97342097dafc725..9f520bcdadfabc8ca9f9ee82b20804fd2c50d1db 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ 
b/tensorflow/compiler/xla/client/lib/testing.cc @@ -34,7 +34,7 @@ namespace { // specified shape. In case of a (nested) tuple shape this is the total byte // size of all sub-shapes within the tuple. int64 DataSizeOfShape(const Shape& shape) { - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { return ShapeUtil::ByteSizeOf(shape); } @@ -47,7 +47,7 @@ int64 DataSizeOfShape(const Shape& shape) { // Creates a XlaOp for an op what generates fake data with the given shape. XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { return Broadcast( ConstantLiteral(builder, LiteralUtil::One(shape.element_type())), AsInt64Slice(shape.dimensions())); @@ -59,22 +59,25 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { return Tuple(builder, parts); } -std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataViaDeviceOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts) { XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape.ToProto(); + if (debug_opts) { + *execution_options.mutable_debug_options() = *debug_opts; + } return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); } } // namespace -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts /*=nullptr*/) { if (DataSizeOfShape(shape) < (1LL << 20)) { StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { @@ -82,24 +85,25 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, // an on-device computation. 
CHECK_EQ(literal_status.status().code(), tensorflow::error::UNIMPLEMENTED); - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client) { + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts /*=nullptr*/) { CHECK(computation.proto().has_host_program_shape()) << "Computation should have progran shape."; auto program_shape = computation.proto().host_program_shape(); std::vector> results; for (const ShapeProto& shape : program_shape.parameters()) { - results.push_back(MakeFakeDataOrDie(Shape(shape), client)); + results.push_back(MakeFakeDataOrDie(Shape(shape), client, debug_opts)); } return results; } diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 03695ce2a339735e3e49522f4fe1bbf2d83a3834..428fa3e93d1b46983aae60176e7c2242d2552fdb 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -29,14 +29,19 @@ namespace xla { // Generates fake data of the given shape on the device or dies. The fake data // is created by performing a computation on the device rather than transferring // data from the host to the device. -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client); +// +// The optional DebugOptions are used when generating fake data on the device. +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts = nullptr); // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. 
+// +// The optional DebugOptions are used when generating fake data on the device. std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client); + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts = nullptr); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h deleted file mode 100644 index 50a3b30ebd1c15eb6d2ace4e351cb41f21db7093..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// Solves systems of linear equations with lower or upper triangular coefficient -// matrices by forward- or back-substitution. Broadcasting along leading -// dimensions, this routine solves one of the matrix systems -// `op(a) * x = b`, or `x * op(a) = b`, -// for the variable `x` given `a` and `b`, where `op(a)` is either -// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. 
-// That is, the innermost matrices in the output satisfy a scalar system -// depending on the value of the value of (left_side, transpose_a, conjugate_a) -// according to: -// (F, F, F) => `output[..., i, k] a[..., k, j] = b[..., i, j]`, -// (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`, -// (F, T, F) => `output[..., i, k] a[..., j, k] = b[..., i, j]`, -// (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`, -// (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`, -// (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`, -// (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`, -// (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`, -// where * denotes complex conjugation and where the index `k` is summed over. -// -// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form -// square matrices. If lower is true (false), then the strictly upper (lower) -// triangular part of each innermost matrix in `a` is assumed to be zero and is -// not accessed. -// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a -// tensor of shape `[..., K, M]`. -// `left_side` is a boolean, indicating whether to solve a system of the form -// op(a) * x = b (true) or x * op(a) = b (false). -// `lower` is a boolean, indicating whether the argument `a` is lower-triangular -// (true) or upper-triangular (false). -// `transpose_a` is a boolean indicating whether the matrix `a` is transposed. -// `conjugate_a` is a boolean indicating whether the entries of `a` are complex -// conjugated (independently of whether they are transposed), so that when both -// transpose_a and conjugate_a are true the effect is a Hermitian adjoint. -// -// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no -// blocking is used. 
-XlaOp TriangularSolve( - XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size = 128, - PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc deleted file mode 100644 index f6a70d64a788d95a456774ccbbcf67f2e5cac98b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc +++ /dev/null @@ -1,333 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace xla { -namespace { - -using TriangularSolveTest = xla::ClientLibraryTestBase; -using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; -using complex64 = xla::complex64; - -xla::Array2D AValsLower() { - return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; -} - -xla::Array2D AValsUpper() { - return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; -} - -xla::Array2D BValsRight() { - return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; -} - -xla::Array2D BValsLeft() { - return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; -} - -xla::Array2D AValsLowerComplex() { - return {{2, 0, 0, 0}, - {complex64(3, 1), 6, 0, 0}, - {4, complex64(7, 2), 9, 0}, - {5, 8, complex64(10, 3), 11}}; -} - -xla::Array2D AValsUpperComplex() { - return {{2, 3, complex64(4, 3), 5}, - {0, 6, complex64(7, 2), 8}, - {0, 0, complex64(9, 1), 10}, - {0, 0, 0, 11}}; -} - -xla::Array2D BValsRightComplex() { - return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; -} - -xla::Array2D BValsLeftComplex() { - return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; -} - -xla::Array2D AValsFull() { - return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; -} - -XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { - xla::XlaBuilder 
builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {0.5, 0.08333334, 0.04629629, 0.03367003}, - {2.5, -0.25, -0.1388889, -0.1010101}, - {4.5, -0.58333331, -0.32407406, -0.23569024}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, - {0.64393939, 0.06565657, -0.03030303, 0.72727273}, - {1.4520202, 0.2003367, 0.01010101, 1.09090909}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, - {0.64393939, 0.06565657, -0.03030303, 0.72727273}, - {1.4520202, 0.2003367, 0.01010101, 1.09090909}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - 
-XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {0.5, 0.08333334, 0.04629629, 0.03367003}, - {2.5, -0.25, -0.1388889, -0.1010101}, - {4.5, -0.58333331, -0.32407406, -0.23569024}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {-0.89646465, -0.69444444, -0.49242424}, - {-0.27441077, -0.24074074, -0.20707071}, - {-0.23232323, -0.22222222, -0.21212121}, - {0.90909091, 1., 1.09090909}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), 
b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/3); - - xla::Array2D expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {-0.89646465, -0.69444444, -0.49242424}, - {-0.27441077, -0.24074074, -0.20707071}, - {-0.23232323, -0.22222222, -0.21212121}, - {0.90909091, 1., 1.09090909}, - }); - - 
ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = - CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); - auto b_data = - CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/true, - /*block_size=*/2); - - xla::Array2D expected({ - {0.5, complex64(0.08333333, 0.08333333), - complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, - {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), - complex64(0.08670034, -0.02104377)}, - {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296), - complex64(0.11026936, -0.03114478)}, - }); - - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = - CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); - auto b_data = - CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); - TriangularSolve(a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - - xla::Array2D expected({ - {0.5, 1., 1.5}, - {0.41666667, 0.33333333, 0.25}, - {complex64(0.20020325, -2.81504065e-01), - complex64(0.13821138, -4.22764228e-01), - complex64(0.07621951, -5.64024390e-01)}, - {complex64(0.19678492, 2.55912786e-01), - complex64(0.17738359, 3.84331116e-01), - complex64(0.15798226, 5.12749446e-01)}, - }); - - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc 
b/tensorflow/compiler/xla/client/local_client.cc index 049cd15738a619294b19d5cf74ca514d7b4a00ad..48b5f94538f453785194bc434a91ee0a10c020c2 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -164,9 +164,8 @@ StatusOr LocalExecutable::Run( // ExecutableRunOptions.eigen_intra_op_thread_pool. // *) The thread pool used for XLA CPU ops is from // backend_->eigen_intra_op_thread_pool(). - ServiceExecutableRunOptions service_options( - run_options, backend_->StreamBorrower(), - backend_->eigen_intra_op_thread_pool()); + ServiceExecutableRunOptions service_options(run_options, + backend_->StreamBorrower()); if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index ddb36680e8b185b053368baffa6f1d5cac50dc07..4f4fc8df31c633749ae9b6dafcdc38d4fd1eba40 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -114,7 +114,7 @@ class LocalClient : public Client { // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // The given ExecutableBuildOptions override any values from TF_XLA_FLAGS + // The given ExecutableBuildOptions overrides any values from XLA_FLAGS // environment variable. 
StatusOr> Compile( const XlaComputation& computation, diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc index fb9ea6ec3fc41d5e04ca125798a8199350470a44..b9bff06cbdbc3525eb19d5df885952c3971d9d6a 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.cc +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -50,7 +50,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); + CHECK_EQ(tile_shape.rank(), 1); std::vector dimensions(1, num_tiles); *result.mutable_tile_shape() = tile_shape.ToProto(); auto& tile_dimension = diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 60df2ec3959216b0564846ad47c21c5bcc01ea57..16381155c3f875dcd55853ebbe004ae58af1590d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -29,6 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -192,9 +195,9 @@ StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { } void XlaBuilder::IsConstantVisitor(const int64 op_handle, - std::set* visited, + absl::flat_hash_set* visited, bool* is_constant) const { - if (visited->count(op_handle) != 0 || !*is_constant) { + if (visited->contains(op_handle) || !*is_constant) { return; } @@ -208,11 +211,21 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, } // TODO(b/32495713): We aren't checking the called computations. break; + case HloOpcode::kGetDimensionSize: { + int64 dimension_number = instr.dimensions(0); + const HloInstructionProto& operand = + *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie()); + Shape operand_shape(operand.shape()); + if (operand_shape.is_dynamic_dimension(dimension_number)) { + *is_constant = false; + } + break; + } // Non functional ops. case HloOpcode::kRng: - case HloOpcode::kCrossReplicaSum: - // TODO(b/33009255): Implmement constant folding for cross replica sum. + case HloOpcode::kAllReduce: + // TODO(b/33009255): Implement constant folding for cross replica sum. 
case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCall: @@ -244,6 +257,29 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num) { + bool param_exists = false; + for (HloInstructionProto& instr : instructions_) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == target_param_num) { + param_exists = true; + Shape param_shape(instr.shape()); + Shape* param_shape_ptr = ¶m_shape; + for (int64 index : target_param_index) { + param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index); + } + param_shape_ptr->set_dynamic_dimension(target_dim_num, + /*is_dynamic=*/true); + *instr.mutable_shape() = param_shape.ToProto(); + } + } + + if (!param_exists) { + return InvalidArgument( + "Asked to mark parameter %lld as dynamic sized parameter, but the " + "doesn't exists", + target_param_num); + } + TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind( DynamicParameterBinding::DynamicParameter{dynamic_size_param_num, dynamic_size_param_index}, @@ -263,27 +299,51 @@ XlaComputation XlaBuilder::BuildAndNoteError() { return build_status.ConsumeValueOrDie(); } -StatusOr XlaBuilder::Build() { +Status XlaBuilder::GetCurrentStatus() const { if (!first_error_.ok()) { string backtrace; first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } - return Build(instructions_.back().id()); + return Status::OK(); +} + +StatusOr XlaBuilder::Build(bool remove_dynamic_dimensions) { + TF_RETURN_IF_ERROR(GetCurrentStatus()); + return Build(instructions_.back().id(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(XlaOp root) { +StatusOr XlaBuilder::Build(XlaOp root, + bool remove_dynamic_dimensions) { if (root.builder_ != this) { return InvalidArgument("Given root operation is not in this computation."); } - return Build(root.handle()); -} + return Build(root.handle(), 
remove_dynamic_dimensions); +} + +StatusOr XlaBuilder::Build(int64 root_id, + bool remove_dynamic_dimensions) { + TF_RETURN_IF_ERROR(GetCurrentStatus()); + + // TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove + // all dynamic dimensions before building xla program until we have support in + // the backend. + if (remove_dynamic_dimensions) { + std::function remove_dynamic_dimension = + [&](ShapeProto* shape) { + if (shape->tuple_shapes_size() != 0) { + for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) { + remove_dynamic_dimension(shape->mutable_tuple_shapes(i)); + } + } + for (int64 i = 0; i < shape->dimensions_size(); ++i) { + shape->set_is_dynamic_dimension(i, false); + } + }; -StatusOr XlaBuilder::Build(int64 root_id) { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); + for (auto& instruction : instructions_) { + remove_dynamic_dimension(instruction.mutable_shape()); + } } HloComputationProto entry; @@ -310,7 +370,10 @@ StatusOr XlaBuilder::Build(int64 root_id) { module->add_computations()->Swap(&e.second); } module->add_computations()->Swap(&entry); - + if (!input_output_aliases_.empty()) { + TF_RETURN_IF_ERROR( + PopulateInputOutputAlias(module, program_shape, input_output_aliases_)); + } *(module->mutable_dynamic_parameter_binding()) = dynamic_parameter_binding_.ToProto(); @@ -323,6 +386,35 @@ StatusOr XlaBuilder::Build(int64 root_id) { return std::move(computation); } +/* static */ Status XlaBuilder::PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases) { + HloInputOutputAliasConfig config(program_shape.result()); + for (auto& alias : input_output_aliases) { + // The HloInputOutputAliasConfig does not do parameter validation as it only + // carries the result shape. 
Maybe it should be constructed with a + // ProgramShape to allow full validation. We will still get an error when + // trying to compile the HLO module, but would be better to have validation + // at this stage. + if (alias.param_number >= program_shape.parameters_size()) { + return InvalidArgument("Invalid parameter number %ld (total %ld)", + alias.param_number, + program_shape.parameters_size()); + } + const Shape& parameter_shape = program_shape.parameters(alias.param_number); + if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) { + return InvalidArgument("Invalid parameter %ld index: %s", + alias.param_number, + alias.param_index.ToString().c_str()); + } + TF_RETURN_IF_ERROR(config.SetUpAlias( + alias.output_index, alias.param_number, alias.param_index, + HloInputOutputAliasConfig::AliasKind::kUserAlias)); + } + *module->mutable_input_output_alias() = config.ToProto(); + return Status::OK(); +} + StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, absl::Span broadcast_dimensions) { @@ -343,7 +435,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); CHECK(ShapeUtil::IsScalar(operand_shape) || - ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + operand_shape.rank() == output_shape.rank()); Shape broadcast_shape = ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); @@ -355,7 +447,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, // Do explicit broadcast for degenerate broadcast. 
std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + for (int i = 0; i < operand_shape.rank(); i++) { if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand_shape.dimensions(i)); @@ -398,8 +490,8 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, binop, lhs_shape, rhs_shape, broadcast_dimensions)); *instr.mutable_shape() = shape.ToProto(); - const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); - const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + const int64 lhs_rank = lhs_shape.rank(); + const int64 rhs_rank = rhs_shape.rank(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; @@ -410,17 +502,19 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; std::vector to_size; - for (int64 size : shape.dimensions()) { - to_size.push_back(size); + std::vector to_size_is_dynamic; + for (int i = 0; i < shape.rank(); i++) { + to_size.push_back(shape.dimensions(i)); + to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i)); } - for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); - from_dim++) { + for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) { int64 to_dim = broadcast_dimensions[from_dim]; to_size[to_dim] = from_shape.dimensions(from_dim); + to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim); } - const Shape& broadcasted_shape = - ShapeUtil::MakeShape(from_shape.element_type(), to_size); + const Shape& broadcasted_shape = ShapeUtil::MakeShape( + from_shape.element_type(), to_size, to_size_is_dynamic); TF_ASSIGN_OR_RETURN( XlaOp broadcasted_operand, InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); @@ -458,18 +552,18 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, XlaOp updated_lhs = 
lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; - if (!ShapeUtil::IsTuple(shape)) { - if (!ShapeUtil::IsTuple(lhs_shape) && + if (!shape.IsTuple()) { + if (!lhs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, lhs_shape)) { // lhs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs)); } - if (!ShapeUtil::IsTuple(rhs_shape) && + if (!rhs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, rhs_shape)) { // rhs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs)); } - if (!ShapeUtil::IsTuple(ehs_shape) && + if (!ehs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, ehs_shape)) { // ehs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs)); @@ -480,16 +574,6 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, }); } -XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -563,10 +647,10 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand, // output, so to append dimensions on the left the instruction's dimensions // should just be the n highest dimension numbers of the output shape where // n is the number of input dimensions. 
- const int64 operand_rank = ShapeUtil::Rank(operand_shape); + const int64 operand_rank = operand_shape.rank(); std::vector dimensions(operand_rank); for (int i = 0; i < operand_rank; ++i) { - dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + dimensions[i] = i + shape.rank() - operand_rank; } return InDimBroadcast(shape, operand, dimensions); }); @@ -579,8 +663,17 @@ XlaOp XlaBuilder::BroadcastInDim( TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); // Output shape, in the case of degenerate broadcast, the out_dim_size is // not necessarily the same as the dimension sizes of the output shape. - const auto& output_shape = + auto output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size); + for (int i = 0; i < broadcast_dimensions.size(); i++) { + if (broadcast_dimensions[i] < 0 || + broadcast_dimensions[i] > out_dim_size.size()) { + return InvalidArgument("Broadcast dimension %lld is out of bound", + broadcast_dimensions[i]); + } + output_shape.set_dynamic_dimension(broadcast_dimensions[i], + operand_shape.is_dynamic_dimension(i)); + } TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape( operand_shape, output_shape, broadcast_dimensions) @@ -639,10 +732,10 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - std::vector starts(ShapeUtil::Rank(shape), 0); + std::vector starts(shape.rank(), 0); std::vector limits(shape.dimensions().begin(), shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); + std::vector strides(shape.rank(), 1); starts[dimno] = start_index; limits[dimno] = limit_index; strides[dimno] = stride; @@ -660,7 +753,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, GetShape(start_indices)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( - 
operand_shape, start_indices_shape, slice_sizes)); + operand_shape, {start_indices_shape}, slice_sizes)); *instr.mutable_shape() = shape.ToProto(); for (int64 size : slice_sizes) { @@ -672,6 +765,34 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, }); } +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + std::vector start_indices_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, + GetOperandShapes(start_indices)); + absl::c_transform(start_indices_shapes, + std::back_inserter(start_indices_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shapes, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + std::vector operands = {operand}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); + }); +} + XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -681,13 +802,38 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferDynamicUpdateSliceShape( + operand_shape, update_shape, {start_indices_shape})); + *instr.mutable_shape() = shape.ToProto(); + + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + {operand, update, start_indices}); + }); +} + +XlaOp 
XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); + std::vector start_indices_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, + GetOperandShapes(start_indices)); + absl::c_transform(start_indices_shapes, + std::back_inserter(start_indices_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicUpdateSliceShape( - operand_shape, update_shape, start_indices_shape)); + operand_shape, update_shape, start_indices_shapes)); *instr.mutable_shape() = shape.ToProto(); + std::vector operands = {operand, update}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - {operand, update, start_indices}); + operands); }); } @@ -780,7 +926,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { + for (int i = 0; i < original_shape.rank(); ++i) { if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape.dimensions(i)); } else { @@ -808,10 +954,9 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true)); TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false)); - TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) == - ShapeUtil::IsTuple(false_shape)); - HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? 
HloOpcode::kTupleSelect - : HloOpcode::kSelect; + TF_RET_CHECK(true_shape.IsTuple() == false_shape.IsTuple()); + HloOpcode opcode = + true_shape.IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect; return TernaryOp(opcode, pred, on_true, on_false); }); } @@ -835,7 +980,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); - if (!ShapeUtil::IsTuple(tuple_shape)) { + if (!tuple_shape.IsTuple()) { return InvalidArgument( "Operand to GetTupleElement() is not a tuple; got %s", ShapeUtil::HumanString(tuple_shape)); @@ -850,36 +995,6 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { }); } -XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -900,6 +1015,18 @@ XlaOp XlaBuilder::DotGeneral(const 
XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + // If one operand is a scalar, just multiply the two operands. + if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { + if (dimension_numbers.rhs_batch_dimensions_size() != 0 || + dimension_numbers.lhs_batch_dimensions_size() != 0 || + dimension_numbers.rhs_contracting_dimensions_size() != 0 || + dimension_numbers.lhs_contracting_dimensions_size() != 0) { + return InvalidArgument( + "Dots with scalar operands must have no contracting or batch " + "dimensions"); + } + return xla::Mul(lhs, rhs); + } TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); @@ -915,13 +1042,13 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, Status XlaBuilder::VerifyConvolution( const Shape& lhs_shape, const Shape& rhs_shape, const ConvolutionDimensionNumbers& dimension_numbers) const { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { + if (lhs_shape.rank() != rhs_shape.rank()) { return InvalidArgument( "Convolution arguments must have same number of " "dimensions. Got: %s and %s", ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } - int num_dims = ShapeUtil::Rank(lhs_shape); + int num_dims = lhs_shape.rank(); if (num_dims < 2) { return InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. 
" @@ -959,27 +1086,29 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config); + feature_group_count, batch_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config); + feature_group_count, batch_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -1007,7 +1136,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); }); } @@ -1015,10 +1144,11 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& 
dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -1026,7 +1156,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( absl::Span> padding, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -1045,14 +1176,15 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, feature_group_count, - instr.window(), dimension_numbers)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, feature_group_count, + batch_group_count, instr.window(), dimension_numbers)); *instr.mutable_shape() = shape.ToProto(); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); + instr.set_batch_group_count(batch_group_count); if (precision_config != nullptr) { *instr.mutable_precision_config() = *precision_config; @@ -1145,7 +1277,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); - if (ShapeUtil::IsArray(shape) && sharding() && + if (shape.IsArray() && sharding() && sharding()->type() == 
OpSharding::Type::OpSharding_Type_OTHER) { // TODO(b/110793772): Support tiled array-shaped infeeds. return InvalidArgument( @@ -1221,7 +1353,7 @@ XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); - if (ShapeUtil::IsArray(shape) && sharding() && + if (shape.IsArray() && sharding() && sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { // TODO(b/110793772): Support tiled array-shaped infeeds. return InvalidArgument( @@ -1334,7 +1466,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { for (int i = 0; i < tokens.size(); ++i) { const XlaOp& operand = tokens[i]; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - if (!ShapeUtil::IsToken(operand_shape)) { + if (!operand_shape.IsToken()) { return InvalidArgument( "All operands to AfterAll must be tokens; operand %d has shape %s", i, ShapeUtil::HumanString(operand_shape)); @@ -1390,147 +1522,6 @@ XlaOp XlaBuilder::CustomCall( }); } -XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); -} - -XlaOp XlaBuilder::Conj(const XlaOp& operand) { - return Complex(Real(operand), Neg(Imag(operand))); -} - -XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMaximum, lhs, rhs, 
broadcast_dimensions); -} - -XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Not(const XlaOp& operand) { - return UnaryOp(HloOpcode::kNot, operand); -} - -XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, - broadcast_dimensions); -} - -XlaOp XlaBuilder::ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, - broadcast_dimensions); -} - -XlaOp XlaBuilder::Abs(const XlaOp& operand) { - return UnaryOp(HloOpcode::kAbs, operand); -} - -XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); -} - -XlaOp XlaBuilder::Exp(const XlaOp& operand) { - return UnaryOp(HloOpcode::kExp, operand); -} - -XlaOp XlaBuilder::Expm1(const XlaOp& operand) { - return UnaryOp(HloOpcode::kExpm1, operand); -} - -XlaOp XlaBuilder::Floor(const XlaOp& operand) { - return UnaryOp(HloOpcode::kFloor, operand); -} - -XlaOp XlaBuilder::Ceil(const XlaOp& operand) { - 
return UnaryOp(HloOpcode::kCeil, operand); -} - -XlaOp XlaBuilder::Round(const XlaOp& operand) { - return UnaryOp(HloOpcode::kRoundNearestAfz, operand); -} - -XlaOp XlaBuilder::Log(const XlaOp& operand) { - return UnaryOp(HloOpcode::kLog, operand); -} - -XlaOp XlaBuilder::Log1p(const XlaOp& operand) { - return UnaryOp(HloOpcode::kLog1p, operand); -} - -XlaOp XlaBuilder::Sign(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSign, operand); -} - -XlaOp XlaBuilder::Clz(const XlaOp& operand) { - return UnaryOp(HloOpcode::kClz, operand); -} - -XlaOp XlaBuilder::Cos(const XlaOp& operand) { - return UnaryOp(HloOpcode::kCos, operand); -} - -XlaOp XlaBuilder::Sin(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSin, operand); -} - -XlaOp XlaBuilder::Tanh(const XlaOp& operand) { - return UnaryOp(HloOpcode::kTanh, operand); -} - -XlaOp XlaBuilder::Real(const XlaOp& operand) { - return UnaryOp(HloOpcode::kReal, operand); -} - -XlaOp XlaBuilder::Imag(const XlaOp& operand) { - return UnaryOp(HloOpcode::kImag, operand); -} - -XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { - return UnaryOp(HloOpcode::kIsFinite, operand); -} - XlaOp XlaBuilder::Transpose(const XlaOp& operand, absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1561,36 +1552,146 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } +namespace { +// Switch from a floating point value to a integer value in such a way that when +// using the integer value to compare, we get the same result for normal values, +// and -Nan is treated as the smallest value, and Nan is treated as the largest +// value. +// If f is a float, and +// x = bit_cast(f); +// y = x < 0 ? numeric_limits::max() - x : x; +// then y is ordered as an int32 such that finite values have the obvious order, +// -0 is ordered before 0, and -NaN and NaN appear at the beginning and end of +// the ordering. 
+// Note that in order to avoid -x to overflow, we calculate +// numeric_limits::max() - x as unsigned, and then convert back to +// signed. +XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, + int64 bit_width) { + PrimitiveType signed_type; + PrimitiveType unsigned_type; + XlaOp max_value; + switch (bit_width) { + case 16: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S16; + unsigned_type = U16; + break; + case 32: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S32; + unsigned_type = U32; + break; + case 64: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S64; + unsigned_type = U64; + break; + default: + return value.builder()->ReportError( + InvalidArgument("Invalid bit width %lld for Comparator floating " + "point parameter.", + bit_width)); + } + auto signed_value = BitcastConvertType(value, signed_type); + auto unsigned_value = BitcastConvertType(value, unsigned_type); + auto flipped_value = + BitcastConvertType(Sub(max_value, unsigned_value), signed_type); + auto is_negative = + Lt(signed_value, + ConstantLiteral(value.builder(), LiteralUtil::Zero(signed_type))); + return Select(is_negative, flipped_value, signed_value); +} +} // namespace + XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + std::vector operands{keys}; + for (const XlaOp& value : values) { + operands.push_back(value); + } + // Build the default less-than comparator (copied from lib/comparators.cc). + // TODO(b/122298745): Remove the deprecated API method so that this code + // duplication can be deleted. 
+ auto b = this->CreateSubBuilder("comparator"); + std::vector operand_types; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand)); + operand_types.push_back(operand_shape.element_type()); + } + + int64 parameter_count = 0; + XlaOp first_lhs_param; + XlaOp first_rhs_param; + + for (auto operand_type : operand_types) { + auto scalar_shape = ShapeUtil::MakeShape(operand_type, {}); + auto lhs_param = + b->Parameter(parameter_count * 2, scalar_shape, + absl::StrCat("p.", parameter_count, ".lhs")); + auto rhs_param = + b->Parameter(parameter_count * 2 + 1, scalar_shape, + absl::StrCat("p.", parameter_count, ".rhs")); + if (parameter_count == 0) { + first_lhs_param = lhs_param; + first_rhs_param = rhs_param; + } + ++parameter_count; + } + if (primitive_util::IsFloatingPointType(operand_types[0])) { + PrimitiveType compare_type = operand_types[0]; + // Special-case handling for BF16. We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. 
+ if (compare_type == BF16) { + compare_type = F32; + first_lhs_param = b->ConvertElementType(first_lhs_param, F32); + first_rhs_param = b->ConvertElementType(first_rhs_param, F32); + } + int64 bit_width = primitive_util::BitWidth(compare_type); + first_lhs_param = + BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width); + first_rhs_param = + BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width); + } + Lt(first_lhs_param, first_rhs_param); + + TF_ASSIGN_OR_RETURN(auto comparator, b->Build()); + return Sort(operands, comparator, dimension, /*is_stable=*/false); + }); +} + +XlaOp XlaBuilder::Sort(absl::Span operands, + const XlaComputation& comparator, int64 dimension, + bool is_stable) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; + instr.set_is_stable(is_stable); std::vector operand_shape_ptrs; - TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); - operand_shape_ptrs.push_back(&keys_shape); - TF_ASSIGN_OR_RETURN(std::vector values_shapes, - GetOperandShapes(values)); - absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs), + TF_ASSIGN_OR_RETURN(std::vector operand_shapes, + GetOperandShapes(operands)); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kSort, operand_shape_ptrs)); *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { - TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); - dimension = ShapeUtil::Rank(keys_shape) - 1; + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(operands[0])); + dimension = keys_shape.rank() - 1; } instr.add_dimensions(dimension); - std::vector operands{keys}; - operands.insert(operands.end(), values.begin(), values.end()); + AddCalledComputation(comparator, &instr); return AddInstruction(std::move(instr), HloOpcode::kSort, operands); }); } -XlaOp XlaBuilder::Pow(const 
XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1616,10 +1717,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, }); } -XlaOp XlaBuilder::Neg(const XlaOp& operand) { - return UnaryOp(HloOpcode::kNegate, operand); -} - XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return TernaryOp(HloOpcode::kClamp, min, operand, max); @@ -1647,12 +1744,12 @@ XlaOp XlaBuilder::Map(absl::Span operands, *instr.mutable_shape() = shape.ToProto(); Shape output_shape(instr.shape()); - const int64 output_rank = ShapeUtil::Rank(output_shape); + const int64 output_rank = output_shape.rank(); AddCalledComputation(computation, &instr); std::vector new_operands(operands.begin(), operands.end()); for (XlaOp& new_operand : new_operands) { TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand)); - const int64 rank = ShapeUtil::Rank(shape); + const int64 rank = shape.rank(); if (rank != output_rank) { TF_ASSIGN_OR_RETURN(new_operand, InDimBroadcast(output_shape, new_operand, {})); @@ -1861,7 +1958,7 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); + std::vector all_dimnos(operand_shape.rank()); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); return Reduce(operand, init_value, computation, all_dimnos); }); @@ -2000,8 +2097,8 @@ XlaOp XlaBuilder::CrossReplicaSum( TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); auto b = CreateSubBuilder("sum"); - b->Add(b->Parameter(/*parameter_number=*/0, 
scalar_shape, "x"), - b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); TF_ASSIGN_OR_RETURN(auto computation, b->Build()); return CrossReplicaSum(operand, computation, replica_groups, /*channel_id=*/absl::nullopt); @@ -2015,8 +2112,8 @@ XlaOp XlaBuilder::CrossReplicaSum( return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - {&operand_shape})); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferAllReduceShape({&operand_shape})); *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : replica_groups) { @@ -2029,8 +2126,7 @@ XlaOp XlaBuilder::CrossReplicaSum( AddCalledComputation(computation, &instr); - return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, - {operand}); + return AddInstruction(std::move(instr), HloOpcode::kAllReduce, {operand}); }); } @@ -2111,6 +2207,14 @@ XlaOp XlaBuilder::CollectivePermute( }); } +XlaOp XlaBuilder::ReplicaId() { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {}); + }); +} + XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -2288,7 +2392,7 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, ShapeUtil::HumanStringWithLayout(operand_shape)); } // TODO(b/111544877): Support tuple shapes. 
- if (!ShapeUtil::IsArray(operand_shape)) { + if (!operand_shape.IsArray()) { return InvalidArgument("SendToHost only supports array shapes, shape: %s", ShapeUtil::HumanString(operand_shape)); } @@ -2328,7 +2432,7 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, } // TODO(b/111544877): Support tuple shapes. - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { return InvalidArgument( "RecvFromHost only supports array shapes, shape: %s", ShapeUtil::HumanString(shape)); @@ -2381,7 +2485,7 @@ StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { TF_RETURN_IF_ERROR(LookUpInstruction(operand).status()); bool is_constant = true; - std::set visited; + absl::flat_hash_set visited; IsConstantVisitor(operand.handle(), &visited, &is_constant); return is_constant; } @@ -2428,21 +2532,58 @@ StatusOr XlaBuilder::BuildConstantSubGraph( worklist.pop(); TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, LookUpInstructionByHandle(handle)); - for (int64 id : instr_proto->operand_ids()) { - if (related_ops.insert(id).second) { - worklist.push(id); + + if (instr_proto->opcode() == + HloOpcodeString(HloOpcode::kGetDimensionSize)) { + // At this point, BuildConstantSubGraph should never encounter a + // GetDimensionSize with a dynamic dimension. IsConstant check would have + // failed at the beginning of this function. + // + // Replace GetDimensionSize with a Constant representing the static bound + // of the shape. 
+ int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); + + TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension)); + auto constant_dimension_size = + static_cast(operand_proto->shape().dimensions(dimension)); + + Literal literal = LiteralUtil::CreateR0(constant_dimension_size); + + HloInstructionProto const_instr; + *const_instr.mutable_shape() = literal.shape().ToProto(); + *const_instr.mutable_literal() = literal.ToProto(); + *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant); + + const_instr.set_id(handle); + *const_instr.mutable_name() = + GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id()); + *entry.add_instructions() = + const_instr; // Add to the result constant graph. + } else { + for (int64 id : instr_proto->operand_ids()) { + if (related_ops.insert(id).second) { + worklist.push(id); + } + } + for (int64 called_id : instr_proto->called_computation_ids()) { + related_calls.insert(called_id); } - } - for (int64 called_id : instr_proto->called_computation_ids()) { - related_calls.insert(called_id); } } // Add related ops to the computation. for (int64 id : related_ops) { - auto* instr = entry.add_instructions(); TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src, LookUpInstructionByHandle(id)); + + if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) { + continue; + } + auto* instr = entry.add_instructions(); + *instr = *instr_src; // Ensures that the instruction names are unique among the graph. 
const string& new_name = @@ -2715,12 +2856,21 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } +XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, + absl::Span slice_sizes) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); +} XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices) { + return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); +} + XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension) { return builder->ConcatInDim(operands, dimension); @@ -2744,32 +2894,38 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kEq, lhs, rhs, + broadcast_dimensions); } XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kNe, lhs, rhs, + broadcast_dimensions); } XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kGe, lhs, rhs, + broadcast_dimensions); } XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kGt, lhs, rhs, + broadcast_dimensions); } -XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span 
broadcast_dimensions) { - return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kLe, lhs, rhs, + broadcast_dimensions); } -XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kLt, lhs, rhs, + broadcast_dimensions); } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, @@ -2786,38 +2942,42 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count, precision_config); + feature_group_count, batch_group_count, + precision_config); } XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralPadding( - lhs, rhs, window_strides, padding, feature_group_count, precision_config); + lhs, rhs, window_strides, padding, feature_group_count, batch_group_count, + precision_config); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp 
ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -2826,11 +2986,12 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count, precision_config); + dimension_numbers, feature_group_count, batch_group_count, + precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, @@ -2838,6 +2999,29 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type, return operand.builder()->Fft(operand, fft_type, fft_length); } +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(const Shape& b_shape, builder->GetShape(b)); + xla::TriangularSolveOptions& options = + *instr.mutable_triangular_solve_options(); + options.set_left_side(left_side); + options.set_lower(lower); + options.set_unit_diagonal(unit_diagonal); + options.set_transpose_a(transpose_a); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape( + a_shape, b_shape, 
options)); + *instr.mutable_shape() = shape.ToProto(); + + return builder->AddInstruction(std::move(instr), + HloOpcode::kTriangularSolve, {a, b}); + }); +} + XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) { return builder->Infeed(shape, config); } @@ -2867,78 +3051,96 @@ XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, operand_shapes_with_layout); } -XlaOp Complex(const XlaOp& real, const XlaOp& imag, +XlaOp Complex(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return real.builder()->Complex(real, imag, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs, + broadcast_dimensions); } -XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } +XlaOp Conj(const XlaOp& operand) { + return Complex(Real(operand), Neg(Imag(operand))); +} XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs, + broadcast_dimensions); } XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs, + broadcast_dimensions); } XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs, + broadcast_dimensions); } XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs, + broadcast_dimensions); } XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); + return 
lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs, + broadcast_dimensions); } XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs, + broadcast_dimensions); } XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs, + broadcast_dimensions); } XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->And(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs, + broadcast_dimensions); } XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs, + broadcast_dimensions); } XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs, + broadcast_dimensions); } -XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } +XlaOp Not(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kNot, operand); +} XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, + broadcast_dimensions); } XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, + broadcast_dimensions); } XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, 
absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, + broadcast_dimensions); } XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, @@ -3010,6 +3212,8 @@ XlaOp CollectivePermute( return operand.builder()->CollectivePermute(operand, source_target_pairs); } +XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); } + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, Padding padding, @@ -3031,48 +3235,73 @@ XlaOp SelectAndScatterWithGeneralPadding( init_value, scatter); } -XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } +XlaOp Abs(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kAbs, operand); +} -XlaOp Atan2(const XlaOp& y, const XlaOp& x, +XlaOp Atan2(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return y.builder()->Atan2(y, x, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs, + broadcast_dimensions); } -XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); } - -XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); } - -XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); } - -XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); } - -XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); } - -XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); } - -XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); } - -XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); } - -XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); } - -XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); } - -XlaOp Sin(const XlaOp& 
operand) { return operand.builder()->Sin(operand); } - -XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); } - -XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } - -XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } +XlaOp Exp(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kExp, operand); +} +XlaOp Expm1(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand); +} +XlaOp Floor(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kFloor, operand); +} +XlaOp Ceil(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kCeil, operand); +} +XlaOp Round(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand); +} +XlaOp Log(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kLog, operand); +} +XlaOp Log1p(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand); +} +XlaOp Sign(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kSign, operand); +} +XlaOp Clz(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kClz, operand); +} +XlaOp Cos(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kCos, operand); +} +XlaOp Sin(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kSin, operand); +} +XlaOp Tanh(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kTanh, operand); +} +XlaOp Real(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kReal, operand); +} +XlaOp Imag(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kImag, operand); +} +XlaOp Sqrt(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand); +} +XlaOp Rsqrt(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand); +} XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, 
absl::Span broadcast_dimensions) { - return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs, + broadcast_dimensions); } XlaOp IsFinite(const XlaOp& operand) { - return operand.builder()->IsFinite(operand); + return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand); } XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { @@ -3083,7 +3312,9 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { return operand.builder()->BitcastConvertType(operand, new_element_type); } -XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } +XlaOp Neg(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kNegate, operand); +} XlaOp Transpose(const XlaOp& operand, absl::Span permutation) { return operand.builder()->Transpose(operand, permutation); @@ -3097,6 +3328,12 @@ XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { return keys.builder()->Sort(keys, values, dimension); } +XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension, bool is_stable) { + return operands[0].builder()->Sort(operands, comparator, dimension, + is_stable); +} + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 098efb60f9bdca8306ff771a505f4a225dea9f7d..129e51674293fe7decd041ed05641519a8e8e444 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -56,6 +56,9 @@ class XlaOp { } ~XlaOp() = default; + XlaOp(const XlaOp& other) = default; + XlaOp& operator=(const XlaOp& other) = default; + // Precondition: !IsUninitialized(). // // It's very common to do foo.builder()->bar(). Without this precondition, if @@ -197,11 +200,19 @@ class XlaBuilder { // status. 
Note that all ops that have been enqueued will be moved to the // computation being returned. The root of the computation will be the last // added operation. - StatusOr Build(); + // + // `remove_dynamic_dimensions` tells the builder whether to remove the + // dyanmic dimensions information in all ops. + // + // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the + // dynamic dimensions information when XLA backend can handle dynamic + // dimensions. + StatusOr Build(bool remove_dynamic_dimensions = true); // Overload of Build which specifies a particular root instruction for the // computation. - StatusOr Build(XlaOp root); + StatusOr Build(XlaOp root, + bool remove_dynamic_dimensions = true); // Builds the computation with the requested operations, or notes an error in // the parent XlaBuilder and returns an empty computation if building failed. @@ -227,6 +238,10 @@ class XlaBuilder { // See also set_die_immediately_on_error(). Status first_error() const { return first_error_; } + // Returns the current status of the builder, complete with the stack trace + // information. + Status GetCurrentStatus() const; + // Returns the shape of the given op. StatusOr GetShape(const XlaOp& op) const; @@ -269,6 +284,10 @@ class XlaBuilder { // and its real dynamic size is represented by `dynamic_param_index` in // parameter `dynamic_param_num`. // + // Note that this should be called before the dynamic parameters are used to + // create other operations, otherwise created operations won't have the + // dynamic dimensions information. + // // TODO(b/119520625): Remove this API once we have more dynamic shape infra // ready. Status SetDynamicBinding(int64 dynamic_size_param_num, @@ -276,9 +295,24 @@ class XlaBuilder { int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num); + // Adds a new input/output alias. 
Since the input/ouput shape information are + // not available until the computation is built, and eventual error in the + // arguments of this API will be detected only at computation Build() time. + void SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + input_output_aliases_.push_back({output_index, param_number, param_index}); + } + private: + // Describes an input/output alias as inserted by the SetUpAlias() API. + struct InputOutputAlias { + ShapeIndex output_index; + int64 param_number; + ShapeIndex param_index; + }; + // Build helper which takes the id of the root operation.. - StatusOr Build(int64 root_id); + StatusOr Build(int64 root_id, bool remove_dynamic_dimensions); // Description for the methods below can be found in the corresponding public // functions section in this file. @@ -288,38 +322,6 @@ class XlaBuilder { XlaOp ConstantLiteral(const LiteralSlice& literal); - template - XlaOp ConstantR0(NativeT value); - template - XlaOp ConstantR1(absl::Span values); - XlaOp ConstantR1(const tensorflow::core::Bitmap& values); - template - XlaOp ConstantR2( - std::initializer_list> values); - template - XlaOp ConstantFromArrayWithLayout(const Array& values, - const Layout& layout); - template - XlaOp ConstantFromArray(const Array& values); - template - XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout); - template - XlaOp ConstantR2FromArray2D(const Array2D& values); - template - XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout); - template - XlaOp ConstantR3FromArray3D(const Array3D& values); - template - XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout); - template - XlaOp ConstantR4FromArray4D(const Array4D& values); - - template - XlaOp ConstantR1(int64 length, NativeT value); - XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); @@ -344,11 +346,18 @@ class XlaBuilder { XlaOp 
SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); + ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); + XlaOp DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes); + ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64 dimension); @@ -360,24 +369,6 @@ class XlaBuilder { XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); @@ -387,28 +378,28 @@ class XlaBuilder { XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, 
absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -418,6 +409,7 @@ class XlaBuilder { absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp Fft(const XlaOp& operand, FftType fft_type, @@ -441,50 +433,6 @@ class XlaBuilder { const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); - XlaOp Complex(const XlaOp& real, const XlaOp& imag, - absl::Span broadcast_dimensions = {}); - - XlaOp Conj(const XlaOp& operand); - - XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp 
Xor(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Not(const XlaOp& operand); - - XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); @@ -527,6 +475,8 @@ class XlaBuilder { const XlaOp& operand, const std::vector>& source_target_pairs); + XlaOp ReplicaId(); + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -541,44 +491,6 @@ class XlaBuilder { absl::Span> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); - XlaOp Abs(const XlaOp& operand); - - XlaOp Atan2(const XlaOp& y, const XlaOp& x, - absl::Span broadcast_dimensions = {}); - - XlaOp Exp(const XlaOp& operand); - - XlaOp Expm1(const XlaOp& operand); - - XlaOp Floor(const XlaOp& operand); - - XlaOp Ceil(const XlaOp& operand); - - XlaOp Round(const XlaOp& operand); - - XlaOp Log(const XlaOp& operand); - - XlaOp Log1p(const XlaOp& operand); - - XlaOp Sign(const XlaOp& operand); - - XlaOp Clz(const XlaOp& operand); - - XlaOp Cos(const XlaOp& operand); - - XlaOp Sin(const XlaOp& operand); - - XlaOp Tanh(const XlaOp& operand); - - XlaOp Real(const XlaOp& operand); - - XlaOp Imag(const XlaOp& operand); - - XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp IsFinite(const XlaOp& operand); - XlaOp Iota(const Shape& shape, int64 iota_dimension); XlaOp Iota(PrimitiveType type, int64 size); @@ -589,14 +501,15 @@ class XlaBuilder { XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - XlaOp Neg(const XlaOp& operand); - 
XlaOp Transpose(const XlaOp& operand, absl::Span permutation); XlaOp Rev(const XlaOp& operand, absl::Span dimensions); + ABSL_DEPRECATED("Use form with comparator computation instead") XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); + XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension = -1, bool is_stable = false); XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -711,7 +624,8 @@ class XlaBuilder { // operation such as `RngNormal` or `Infeed`. The visitor walks the // computation starting at a given operation and sets is_constant to false iff // a parameter or stateful operation is encountered. - void IsConstantVisitor(const int64 op_handle, std::set* visited, + void IsConstantVisitor(const int64 op_handle, + absl::flat_hash_set* visited, bool* is_constant) const; // Checks bounds for convolution parameters. @@ -729,6 +643,12 @@ class XlaBuilder { int64 GetNextId() { return ++next_id_; } + // Populates the module with the input/output alias information stored within + // the input_output_aliases vector. + static Status PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases); + string name_; // Name to use for the built computation. // The next sequential ID for every instruction/computation contained within @@ -748,6 +668,9 @@ class XlaBuilder { // Dynamic parameter configuration of this computation. DynamicParameterBinding dynamic_parameter_binding_; + // Holds the input/output alias information populated by the SetUpAlias() API. + std::vector input_output_aliases_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. 
absl::flat_hash_map handle_to_index_; @@ -778,48 +701,6 @@ class XlaBuilder { const Shape& shape, const string& name); friend XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); - template - friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); - template - friend XlaOp ConstantR1(XlaBuilder* builder, - absl::Span values); - friend XlaOp ConstantR1(XlaBuilder* builder, - const tensorflow::core::Bitmap& values); - template - friend XlaOp ConstantR2( - XlaBuilder* builder, - std::initializer_list> values); - template - friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, - const Array& values, - const Layout& layout); - template - friend XlaOp ConstantFromArray(XlaBuilder* builder, - const Array& values); - template - friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, - const Array2D& values, - const Layout& layout); - template - friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, - const Array2D& values); - template - friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, - const Array3D& values, - const Layout& layout); - template - friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, - const Array3D& values); - template - friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, - const Array4D& values, - const Layout& layout); - template - friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, - const Array4D& values); - - template - friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); friend XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); @@ -849,9 +730,14 @@ class XlaBuilder { friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); + friend XlaOp DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes); friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const 
XlaOp& update, + absl::Span start_indices); friend XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension); @@ -881,23 +767,25 @@ class XlaBuilder { const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, @@ -906,9 +794,13 @@ class XlaBuilder { absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); + friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); friend XlaOp Infeed(XlaBuilder* 
builder, const Shape& shape, const string& config); friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, @@ -987,6 +879,7 @@ class XlaBuilder { friend XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); + friend XlaOp ReplicaId(XlaBuilder* builder); friend XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -1017,6 +910,8 @@ class XlaBuilder { friend XlaOp Tanh(const XlaOp& operand); friend XlaOp Real(const XlaOp& operand); friend XlaOp Imag(const XlaOp& operand); + friend XlaOp Sqrt(const XlaOp& operand); + friend XlaOp Rsqrt(const XlaOp& operand); friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); @@ -1033,6 +928,9 @@ class XlaBuilder { friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); friend XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension); + friend XlaOp Sort(absl::Span operands, + const XlaComputation& comparator, int64 dimension, + bool is_stable); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, @@ -1290,10 +1188,15 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // The size of the slice in each dimension is passed in 'slice_sizes', // which specify the end point of exclusive slice intervals in each // dimension [start, start + size). -// The shape of 'start_indices' must be rank == 1, with dimension size -// equal to the rank of the 'operand'. +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. 
+XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, + absl::Span slice_sizes); + +ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); @@ -1309,10 +1212,15 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] // [7 8 9] [7 8 9 ] // -// The shape of 'start_indices' must be rank == 1, with dimension size -// equal to the rank of the 'operand'. +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. // Slice index calculations are computed modulo update dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); + +ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); @@ -1372,7 +1280,7 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, // default convolution dimension numbers. 
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1381,6 +1289,7 @@ XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1388,7 +1297,7 @@ XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1397,7 +1306,7 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1409,6 +1318,7 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and @@ -1416,6 +1326,32 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, XlaOp Fft(const XlaOp& operand, FftType 
fft_type, absl::Span fft_length); +// Solves systems of linear equations with lower or upper triangular coefficient +// matrices by forward- or back-substitution. Broadcasting along leading +// dimensions, this routine solves for x in one of the matrix systems +// `op(a) * x = b`, or `x * op(a) = b`, +// for the variable `x` given `a` and `b`, where `op(a)` is either +// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. +// +// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form +// square matrices. If `lower` is true (false), then the strictly upper +// (lower) triangular part of each innermost matrix in `a` is assumed to be +// zero and is not accessed. +// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a +// tensor of shape `[..., K, M]`. +// * `left_side` is a boolean, indicating whether to solve a system of the form +// op(a) * x = b (true) or x * op(a) = b (false). +// * `lower` is a boolean, indicating whether the argument `a` is +// lower-triangular +// (true) or upper-triangular (false). +// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be +// 1 and not accessed. +// * `transpose_a` indicates which function `op` we use to transform the tensor +// `a`: the identity function, transpose(a), or conjugate(transpose(a)) +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); + // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1515,9 +1451,33 @@ XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); +// Overload to call And with 3 or more operands. 
We need the following somewhat +// convoluted overload set to disambiguate with the overload that takes the +// `broadcast_dimensions` optional param. +inline XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) { + return And(op1, And(op2, op3)); +} +template +XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3, + const XlaOpTs&... operands) { + return And(op1, And(op2, And(op3, operands...))); +} + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); +// Overload to call Or with 3 or more operands. As with `And`, we need the +// following complicated overload set to handle the default arg in the `Or` +// overload above. +inline XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) { + return Or(op1, Or(op2, op3)); +} +template +XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3, + const XlaOpTs&... operands) { + return Or(op1, Or(op2, Or(op3, operands...))); +} + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); @@ -1610,6 +1570,9 @@ XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); +// Enqueues an operation that returns the replica ID. +XlaOp ReplicaId(XlaBuilder* builder); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1677,14 +1640,24 @@ XlaOp Real(const XlaOp& operand); // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); +// Enqueues a sqrt computation onto the computation. +XlaOp Sqrt(const XlaOp& operand); + +// Enqueues a rsqrt computation onto the computation. +XlaOp Rsqrt(const XlaOp& operand); + // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); -// Enqueues an operator that tests if the operand's values are finite, i.e., -// not Inf or NaN. 
Defined only for floating-point types. Returns an array of -// booleans with the same shape where entries are true iff the corresponding -// entry was NaN. +// Enqueues an operator that tests if the operand's values are finite, i.e., not +// +/-Inf or NaN. Returns an array of booleans with the same shape where +// entries are true iff the corresponding entry was not infinite or NaN. +// +// Defined only for real-valued (i.e. not complex) floating-point types; raises +// an error for other types. +// +// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h. XlaOp IsFinite(const XlaOp& operand); // Enqueues an iota operation onto the computation. @@ -1720,7 +1693,7 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // of keys, in ascending order. // * If the keys have higher rank, the keys are sorted along the provided // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension -// value of 0 will indepenently sort every column, and a dimension value of 1 +// value of 0 will independently sort every column, and a dimension value of 1 // will independently sort each row. If no dimension number is provided, then // the last dimension is chosen by default. // @@ -1730,9 +1703,39 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // * The result is a tuple that consists of a sorted tensor of keys (along the // provided dimension, as above) as the first element, and tensors with their // corresponding values as the other elements. +ABSL_DEPRECATED("Use form with comparator computation instead") XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); +// Enqueues a sort instruction onto the computation, using 'comparator' for +// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' +// determines whether the stable sorting should be used. +// If only one operand is provided: +// * If the operand is a rank-1 tensor (an array), the result is a sorted array. 
+// The resulting sorting order has the property that for all index positions +// i, j with i < j, either +// comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or +// comparator(value[i], value[j]) = true. +// * If the operand has higher rank, the operand is sorted along the provided +// dimension. For example, for a rank-2 tensor (a matrix), a dimension value +// of 0 will independently sort every column, and a dimension value of 1 will +// independently sort each row. If no dimension number is provided, then the +// last dimension is chosen by default. For the dimension which is sorted, the +// same sorting order applies as in the rank-1 case. +// +// If more than one operand is provided: +// * All operands must be tensors with the same dimensions. The element types of +// the tensors may be different. +// * The result is a tuple that consists of the operands in sorted order (along +// the provided dimension, as above). The same permutation as implied by the +// comparison computation is applied to all operand tensors. When comparing +// two index positions, 'comparator' is called with 2 * n scalar parameters, +// where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at +// two index positions. +// Default comparator computations can be found in lib/comparators.h +XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension = -1, bool is_stable = false); + // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -1871,81 +1874,6 @@ XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); // Implementation details below this point. 
// -template -XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(LiteralUtil::CreateR0(value)); -} - -template -XlaOp XlaBuilder::ConstantR1(absl::Span values) { - return ConstantLiteral(LiteralUtil::CreateR1(values)); -} - -template -XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(LiteralUtil::CreateR1(values)); -} - -template -XlaOp XlaBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(LiteralUtil::CreateR2(values)); -} - -template -XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, - const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(LiteralUtil::CreateFromArray(values)); -} - -template -XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(LiteralUtil::CreateR2FromArray2D(values)); -} - -template -XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { - return ConstantFromArray(values); -} - -template -XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -XlaOp XlaBuilder::ConstantR4FromArray4D(const 
Array4D& values) { - return ConstantFromArray(values); -} - // Free function template implementations. template diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index b3f5be300d3f15397ad33858a6a9cab5f6029688..c9fa738a19d0928d56ac4b98beb5fc0ed195518b 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -39,7 +40,8 @@ using ::testing::HasSubstr; class XlaBuilderTest : public ::testing::Test { protected: StatusOr> BuildHloModule(XlaBuilder* b) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b->Build(/*remove_dynamic_dimensions=*/false)); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( @@ -50,7 +52,8 @@ class XlaBuilderTest : public ::testing::Test { // Overload which explicitly specifies the root instruction. 
StatusOr> BuildHloModule(XlaBuilder* b, XlaOp root) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b->Build(root, /*remove_dynamic_dimensions=*/false)); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( @@ -132,6 +135,38 @@ TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { op::ShiftRightLogical(op::Constant(), op::Constant())); } +TEST_F(XlaBuilderTest, VariadicAnd) { + XlaBuilder b(TestName()); + Shape s = ShapeUtil::MakeShape(PRED, {}); + And(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), + Parameter(&b, 2, s, "p2")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + // Don't specify in the test whether And(x, y, z) is right- or + // left-associative; accept either one. + EXPECT_THAT( + module->entry_computation()->root_instruction(), + ::testing::AnyOf(op::And(op::Parameter(0), + op::And(op::Parameter(1), op::Parameter(2))), + op::And(op::And(op::Parameter(0), op::Parameter(1)), + op::Parameter(2)))); +} + +TEST_F(XlaBuilderTest, VariadicOr) { + XlaBuilder b(TestName()); + Shape s = ShapeUtil::MakeShape(PRED, {}); + Or(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), + Parameter(&b, 2, s, "p2")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + // Don't specify in the test whether Or(x, y, z) is right- or + // left-associative; accept either one. 
+ EXPECT_THAT( + module->entry_computation()->root_instruction(), + ::testing::AnyOf( + op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2))), + op::Or(op::Or(op::Parameter(0), op::Parameter(1)), + op::Parameter(2)))); +} + TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { XlaBuilder b(TestName()); ConstantR0(&b, 1) >> ConstantR0(&b, 2); @@ -446,6 +481,461 @@ TEST_F(XlaBuilderTest, ProtoMatches) { EXPECT_EQ(c0_string, c1_string); } +TEST_F(XlaBuilderTest, DynamicParameter) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/1, + /*dynamic_size_param_index=*/{}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/p0)); + const Shape& param_shape = module->entry_computation() + ->parameter_instruction(0) + ->shape() + .tuple_shapes(1); + EXPECT_TRUE(param_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicUnary) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + Neg(gte); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicBinary) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = 
ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {5}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(F32, {5}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1, {0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + 
{ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + BroadcastInDim(gte, /*out_dim_size=*/{3, 5, 4}, + /*broadcast_dimensions=*/{1, 2}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(F32, {1, 15}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1, /*broadcast_dimensions=*/{0}); // f32[<=10, 15] + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {10}), ShapeUtil::MakeShape(F32, {10}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + 
/*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + + Select(gte0, gte1, gte1); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicPad) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pad_val = ConstantR0(&b, -1); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + PaddingConfig padding_config; + for (int i = 0; i < 2; i++) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(0); + dimension->set_interior_padding(0); + } + Pad(gte, pad_val, padding_config); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicConvolution) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}), + ShapeUtil::MakeShape(F32, {2, 2, 128, 8}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + 
ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/2)); + auto input = GetTupleElement(p0, 0); + auto filter = GetTupleElement(p0, 1); + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), + {true, false, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicDot) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 3, 4}), + ShapeUtil::MakeShape(F32, {2, 4, 5}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + 
/*target_param_index=*/{0}, + /*target_dim_num=*/1)); + + auto lhs = GetTupleElement(p0, 0); + auto rhs = GetTupleElement(p0, 1); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + DotGeneral(lhs, rhs, dnums); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReduce) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4, 3}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + auto gte = GetTupleElement(p0, 0); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + Reduce(gte, init, sum, {0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReduceWindow) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 4, 8}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0.f); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + 
/*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4}, + /*window_strides=*/{1, 1, 1}, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectAndScatter) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 4, 8}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0.f); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + XlaBuilder bge(TestName()); + Ge(Parameter(&bge, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bge, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto ge, bge.Build()); + + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto source = GetTupleElement(p0, 1); + SelectAndScatter(gte0, ge, {1, 2, 4}, {1, 2, 4}, Padding::kValid, 
source, + init, sum); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReshape) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/2)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/3)); + auto gte = GetTupleElement(p0, 0); // f32[2, 3, <=4, <=5, 6] + Reshape(gte, /*new_sizes=*/{6, 4, 1, 5, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); + EXPECT_TRUE(result_shape.is_dynamic_dimension(3)); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), + {false, true, false, true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelect) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {4, 5, 6}), + ShapeUtil::MakeShape(F32, {4, 5, 6}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + 
/*target_param_index=*/{0}, + /*target_dim_num=*/1)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/1)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Select(pred, gte0, gte1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); + EXPECT_FALSE(result_shape.is_dynamic_dimension(2)); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {4, 5, 6}), + ShapeUtil::MakeShape(F32, {4, 5, 6}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/2)); + auto gte0 = GetTupleElement(p0, 0); // f32[4,<=5,6] + auto gte1 = GetTupleElement(p0, 1); // f32[4,5,<=6] + Select(pred, gte0, gte1); + Status status = BuildHloModule(&b).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Operands to select must be the same shape; " + "got f32[4,<=5,6] and f32[4,5,<=6]")); +} + +TEST_F(XlaBuilderTest, DynamicTranspose) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + 
{ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + Transpose(gte, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {false, true})) + << result_shape; +} + TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { XlaBuilder b(TestName()); AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); @@ -455,5 +945,31 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { ::testing::HasSubstr("All operands to AfterAll must be tokens")); } +TEST_F(XlaBuilderTest, CheckInputOutputAlias) { + XlaBuilder b(TestName()); + auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); + auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1"); + auto add = Add(p0, p1); + auto sub = Sub(p0, p1); + auto root = Tuple(&b, {add, sub}); + + b.SetUpAlias({1}, 0, {}); + b.SetUpAlias({0}, 1, {}); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root)); + + const HloInputOutputAliasConfig& config = module->input_output_alias_config(); + EXPECT_TRUE(config.ParameterHasAlias(0, {})); + EXPECT_TRUE(config.ParameterHasAlias(1, {})); + + auto alias_p0 = config.GetAliasedOutput(0, {}); + ASSERT_TRUE(alias_p0.has_value()); + EXPECT_EQ(*alias_p0, ShapeIndex({1})); + + auto alias_p1 = config.GetAliasedOutput(1, {}); + ASSERT_TRUE(alias_p1.has_value()); + EXPECT_EQ(*alias_p1, ShapeIndex({0})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 
20609cad58d920c0c272899c41efeb99d23cd490..43d9ee0d9a5e689676b00e59d7c59bb0f4e37461 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -22,49 +22,49 @@ limitations under the License. #include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace { -DebugOptions* flag_values; -std::vector* flag_objects; -std::once_flag flags_init; - -void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_llvm_enable_alias_scope_metadata(true); - flags->set_xla_llvm_enable_noalias_metadata(true); - flags->set_xla_llvm_enable_invariant_load_metadata(true); - flags->set_xla_llvm_disable_expensive_passes(false); - flags->set_xla_backend_optimization_level(3); - flags->set_xla_cpu_multi_thread_eigen(true); - flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); - flags->set_xla_eliminate_hlo_implicit_broadcast(true); +DebugOptions DefaultDebugOptionsIgnoringFlags() { + DebugOptions opts; + opts.set_xla_llvm_enable_alias_scope_metadata(true); + opts.set_xla_llvm_enable_noalias_metadata(true); + opts.set_xla_llvm_enable_invariant_load_metadata(true); + opts.set_xla_llvm_disable_expensive_passes(false); + opts.set_xla_backend_optimization_level(3); + opts.set_xla_cpu_multi_thread_eigen(true); + opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); + opts.set_xla_eliminate_hlo_implicit_broadcast(true); + opts.set_xla_hlo_dump_as_html(false); #ifdef INTEL_MKL - flags->set_xla_cpu_use_mkl_dnn(true); + opts.set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL - flags->set_xla_gpu_max_kernel_unroll_factor(4); + opts.set_xla_gpu_max_kernel_unroll_factor(4); // Set cudnn batchnorm off by default; it does not provide a performance win // on average. - flags->set_xla_gpu_use_cudnn_batchnorm(false); + opts.set_xla_gpu_use_cudnn_batchnorm(false); // Run all GPU work on one stream by default. 
Using multiple streams // increases memory usage and we lack strong motivating benchmarks for tuning // the heuristics needed to decide when to run on multiple streams. See // b/77879207. - flags->set_xla_gpu_disable_multi_streaming(true); + opts.set_xla_gpu_disable_multi_streaming(true); // TODO(jlebar): Disable fastmath once doing so is not a performance // regression. - flags->set_xla_cpu_enable_fast_math(true); - flags->set_xla_gpu_enable_fast_min_max(true); + opts.set_xla_cpu_enable_fast_math(true); + opts.set_xla_gpu_enable_fast_min_max(true); - flags->set_xla_force_host_platform_device_count(1); + opts.set_xla_force_host_platform_device_count(1); + return opts; } +static DebugOptions* flag_values; +static std::vector* flag_objects; +static std::once_flag flags_init; + // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. -void AllocateFlags() { - flag_values = new DebugOptions; - - SetDebugOptionsDefaults(flag_values); +static void AllocateFlags() { + flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags()); // Returns a lambda that calls "member_setter" on "flag_values" with the // argument passed in to the lambda. 
@@ -128,24 +128,17 @@ void AllocateFlags() { tensorflow::Flag( "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag( - "xla_hlo_dump_as_graphdef", - bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), - flag_values->xla_hlo_dump_as_graphdef(), - "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag("xla_hlo_dump_as_html", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), + flag_values->xla_hlo_dump_as_html(), + "Dump HLO graphs as an HTML (DOT rendered into SVG " + "inlined in HTML)."), tensorflow::Flag( "xla_hlo_graph_sharding_color", bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), flag_values->xla_hlo_graph_sharding_color(), "Assign colors based on sharding assignments when generating the " "HLO graphs."), - tensorflow::Flag( - "xla_hlo_tfgraph_device_scopes", - bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), - flag_values->xla_hlo_tfgraph_device_scopes(), - "When generating TensorFlow HLO graphs, if the HLO instructions " - "are assigned to a specific device, prefix the name scope with " - "\"devX\" with X being the device ordinal."), tensorflow::Flag( "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), "HLO modules matching this regex will be dumped to LOG(INFO)."), @@ -202,6 +195,16 @@ void AllocateFlags() { "Comma-separated list of hlo passes to be disabled. These names " "must exactly match the passes' names; no whitespace around " "commas."), + tensorflow::Flag( + "xla_disable_all_hlo_passes", + bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, + "Disables all HLO passes. Notes that some passes are necessary for " + "correctness and the invariants that must be satisfied by 'fully " + "optimized' HLO are different for different devices and may change " + "over time. 
The only 'guarantee', such as it is, is that if you " + "compile XLA and dump the optimized HLO for some graph, you should " + "be able to run it again on the same device with the same build of " + "XLA."), tensorflow::Flag( "xla_embed_ir_in_executable", bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), @@ -344,8 +347,6 @@ void AllocateFlags() { ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } -} // namespace - void AppendDebugOptionsFlags(std::vector* flag_list) { std::call_once(flags_init, &AllocateFlags); flag_list->insert(flag_list->end(), flag_objects->begin(), diff --git a/tensorflow/compiler/xla/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h index 60e59abc2a2e0f1cce3de1afc928f9fe36f75b33..dbf86a40f052af09c61da0e1abb3116ef5214357 100644 --- a/tensorflow/compiler/xla/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -29,7 +29,10 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // Fetches a DebugOptions proto message from flags provided to the program. // Flags must be registered with the flags parser using AppendDebugOptionsFlags // first. -xla::DebugOptions GetDebugOptionsFromFlags(); +DebugOptions GetDebugOptionsFromFlags(); + +// Gets a DebugOptions proto that reflects the defaults as if no flags were set. +DebugOptions DefaultDebugOptionsIgnoringFlags(); } // namespace xla diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h index a1463aa15941b9c265db94e2eb3cc176fab6695b..4359f3b7deb8e585494cb2a9c7115eac6a312c8e 100644 --- a/tensorflow/compiler/xla/error_spec.h +++ b/tensorflow/compiler/xla/error_spec.h @@ -30,6 +30,19 @@ struct ErrorSpec { // In effect, this allows the tested operation to produce incorrect results // for inputs outside its mathematical domain. 
bool relaxed_nans; + + // If this is true, then we treat each +/-inf in the actual result as + // equivalent to our choice of either +/-inf or the min/max floating-point + // value. + // + // If the expected result is +/-inf, the actual result must still be +/-inf. + // + // In effect, this allows the tested operation to overflow, so long as it's + // overflowing on "large" values. + // + // (We could have a symmetric more_infs_ok flag if necessary; right now it + // appears not to be.) + bool fewer_infs_ok = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 0f9b591c70d4fd96147958d18bd5fb7dd78a7f3f..230f3b202a4b531c381665471c3856c3feba5a3a 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -77,7 +77,7 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const { } ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( - DeviceAssignment* device_assignment) { + const DeviceAssignment* device_assignment) { device_assignment_ = device_assignment; return *this; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index ba3217f31b55bd1428f67da6154a46c8bc304053..1e744953bd3be58afba5b81c0e2a8ba26665f9c4 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ -// Pulls in the ::stream_executor -> ::xla::se namespace alias. 
-#include "tensorflow/compiler/xla/types.h" - // These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't @@ -28,12 +25,6 @@ class Stream; class Platform; } // namespace stream_executor -namespace tensorflow { -namespace thread { -class ThreadPool; -} // namespace thread -} // namespace tensorflow - namespace Eigen { struct ThreadPoolDevice; } // namespace Eigen @@ -83,7 +74,7 @@ class ExecutableRunOptions { ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); ExecutableRunOptions& set_device_assignment( - DeviceAssignment* device_assignment); + const DeviceAssignment* device_assignment); const DeviceAssignment* device_assignment() const; ExecutableRunOptions& set_rng_seed(int rng_seed); @@ -92,7 +83,7 @@ class ExecutableRunOptions { private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; - DeviceAssignment* device_assignment_ = nullptr; + const DeviceAssignment* device_assignment_ = nullptr; stream_executor::Stream* stream_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index 1fea816a803bfb75b9721393cef8c4dfc249268d..c34e84efc80ba970624d80802841d6ec534b6fd0 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -104,9 +104,9 @@ class Sharding(object): ValueError: The tensor to split was smaller in the split dimension than the number of devices to split over. 
""" - tensor.shape.assert_is_fully_defined() shape = tensor.shape.as_list() - if shape[split_dimension] < num_devices: + if (shape[split_dimension] is not None and + shape[split_dimension] < num_devices): raise ValueError('Split dimension was smaller than the required number ' 'of splits: shape=%r, dimension=%r, num_devices=%r' % (shape, split_dimension, num_devices)) diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index 267701e9c0e42a21d2cda6238520f6a9692e7e76..d756cd74c98b98a6fda099690d966562bd694e2c 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -25,6 +25,8 @@ upper_tabs: path: /xla/operation_semantics - title: Shapes and layout path: /xla/shapes + - title: Tiled layout + path: /xla/tiled_layout - title: Using AOT compilation path: /xla/tfcompile - heading: Tutorials diff --git a/tensorflow/compiler/xla/g3doc/broadcasting.md b/tensorflow/compiler/xla/g3doc/broadcasting.md index 2870869a2cef13a9105b9dc9fa4d657834288f86..5c0525c1e9adf9f37d945170d05e7c18fa3d8852 100644 --- a/tensorflow/compiler/xla/g3doc/broadcasting.md +++ b/tensorflow/compiler/xla/g3doc/broadcasting.md @@ -168,7 +168,7 @@ consult the Broadcasting of a lower-rank array to a higher-rank array **and** broadcasting using degenerate dimensions can both be performed in the same binary operation. -For example, a vector of size 4 and an matrix of size 1x2 can be added together +For example, a vector of size 4 and a matrix of size 1x2 can be added together using broadcast dimensions value of (0): |1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector. @@ -176,7 +176,7 @@ using broadcast dimensions value of (0): First the vector is broadcast up to rank 2 (matrix) using the broadcast dimensions. The single value (0) in the broadcast dimensions indicates that dimension zero of the vector matches to dimension zero of the matrix. 
This -produces an matrix of size 4xM where the value M is chosen to match the +produces a matrix of size 4xM where the value M is chosen to match the corresponding dimension size in the 1x2 array. Therefore, a 4x2 matrix is produced: diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index d888b1f23f36f33ef94ef0e22374e0c796e47a89..db90d184b5218614ac49363ebf2a7e25fffe44de 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -38,25 +38,25 @@ Alltoall is a collective operation that sends data from all cores to all cores. It has two phases: 1. the scatter phase. On each core, the operand is split into `split_count` - number of blocks along the `split_dimensions`, and the blocks are scattered - to all cores, e.g., the ith block is send to the ith core. +number of blocks along the `split_dimensions`, and the blocks are scattered +to all cores, e.g., the ith block is sent to the ith core. 2. the gather phase. Each core concatenates the received blocks along the - `concat_dimension`. +`concat_dimension`. The participating cores can be configured by: - `replica_groups`: each ReplicaGroup contains a list of replica id. If empty, - all replicas belong to one group in the order of 0 - (n-1). Alltoall will be - applied within subgroups in the specified order. For example, replica - groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica - 1, 2, 3, and in the gather phase, the received blocks will be concatenated - in the order of 1, 2, 3; another Alltoall will be applied within replica 4, - 5, 0, and the concatenation order is 4, 5, 0. +all replicas belong to one group in the order of 0 - (n-1). Alltoall will be +applied within subgroups in the specified order.
For example, replica +groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica +1, 2, 3, and in the gather phase, the received blocks will be concatenated +in the order of 1, 2, 3; another Alltoall will be applied within replica 4, +5, 0, and the concatenation order is 4, 5, 0. Prerequisites: - The dimension size of the operand on the split_dimension is divisible by - split_count. +split_count. - The operand's shape is not tuple. `AllToAll(operand, split_dimension, concat_dimension, split_count, @@ -93,7 +93,7 @@ AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); ```
- +
In this example, there are 4 cores participating the Alltoall. On each core, the @@ -387,34 +387,34 @@ For example, let v be an array of 24 elements: ``` let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}; +{{20, 21, 22}, {25, 26, 27}}, +{{30, 31, 32}, {35, 36, 37}}, +{{40, 41, 42}, {45, 46, 47}}}; // Collapse to a single dimension, leaving one dimension. let v012 = Collapse(v, {0,1,2}); then v012 == f32[24] {10, 11, 12, 15, 16, 17, - 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, - 40, 41, 42, 45, 46, 47}; +20, 21, 22, 25, 26, 27, +30, 31, 32, 35, 36, 37, +40, 41, 42, 45, 46, 47}; // Collapse the two lower dimensions, leaving two dimensions. let v01 = Collapse(v, {0,1}); then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17}, - {20, 21, 22, 25, 26, 27}, - {30, 31, 32, 35, 36, 37}, - {40, 41, 42, 45, 46, 47}}; +{20, 21, 22, 25, 26, 27}, +{30, 31, 32, 35, 36, 37}, +{40, 41, 42, 45, 46, 47}}; // Collapse the two higher dimensions, leaving two dimensions. let v12 = Collapse(v, {1,2}); then v12 == f32[8x3] {{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}; +{15, 16, 17}, +{20, 21, 22}, +{25, 26, 27}, +{30, 31, 32}, +{35, 36, 37}, +{40, 41, 42}, +{45, 46, 47}}; ``` @@ -441,9 +441,9 @@ replicas. Note that there are the following restrictions on the `source_target_pair`: - Any two pairs should not have the same target replica id, and they should - not have the same source replica id. +not have the same source replica id. - If a replica id is not a target in any pair, then the output on that replica - is a tensor consists of 0(s) with the same shape as the input. +is a tensor consisting of 0(s) with the same shape as the input.
## Concatenate @@ -480,25 +480,25 @@ Concat({{2, 3}, {4, 5}, {6, 7}}, 0) ``` let a = { - {1, 2}, - {3, 4}, - {5, 6}, +{1, 2}, +{3, 4}, +{5, 6}, }; let b = { - {7, 8}, +{7, 8}, }; Concat({a, b}, 0) >>> { - {1, 2}, - {3, 4}, - {5, 6}, - {7, 8}, +{1, 2}, +{3, 4}, +{5, 6}, +{7, 8}, } ``` Diagram:
- +
## Conditional @@ -548,17 +548,23 @@ Computes a convolution of the kind used in neural networks. Here, a convolution can be thought of as a n-dimensional window moving across a n-dimensional base area and a computation is performed for each possible position of the window. -| Arguments | Type | Semantics | -| --------------------- | -------------------- | ----------------------------- | -| `lhs` | `XlaOp` | rank n+2 array of inputs | -| `rhs` | `XlaOp` | rank n+2 array of kernel | -: : : weights : -| `window_strides` | `ArraySlice` | n-d array of kernel strides | -| `padding` | `ArraySlice< | n-d array of (low, high) | -: : pair>` : padding : -| `lhs_dilation` | `ArraySlice` | n-d lhs dilation factor array | -| `rhs_dilation` | `ArraySlice` | n-d rhs dilation factor array | -| `feature_group_count` | int64 | the number of feature groups | +| Arguments | Type | Semantics | +| --------------------- | ------------------------ | ------------------------ | +| `lhs` | `XlaOp` | rank n+2 array of inputs | +| `rhs` | `XlaOp` | rank n+2 array of kernel | +: : : weights : +| `window_strides` | `ArraySlice` | n-d array of kernel | +: : : strides : +| `padding` | `ArraySlice< pair>` : padding : +| `lhs_dilation` | `ArraySlice` | n-d lhs dilation factor | +: : : array : +| `rhs_dilation` | `ArraySlice` | n-d rhs dilation factor | +: : : array : +| `feature_group_count` | int64 | the number of feature | +: : : groups : +| `batch_group_count` | int64 | the number of batch | +: : : groups : Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2 array describing the base area. This is called the input, even though of course @@ -566,20 +572,20 @@ the rhs is also an input. In a neural network, these are the input activations. The n+2 dimensions are, in this order: * `batch`: Each coordinate in this dimension represents an independent input - for which convolution is carried out. +for which convolution is carried out. 
* `z/depth/features`: Each (y,x) position in the base area has a vector - associated to it, which goes into this dimension. +associated to it, which goes into this dimension. * `spatial_dims`: Describes the `n` spatial dimensions that define the base - area that the window moves across. +area that the window moves across. The `rhs` argument is a rank n+2 array describing the convolutional filter/kernel/window. The dimensions are, in this order: * `output-z`: The `z` dimension of the output. * `input-z`: The size of this dimension times `feature_group_count` should - equal the size of the `z` dimension in lhs. +equal the size of the `z` dimension in lhs. * `spatial_dims`: Describes the `n` spatial dimensions that define the n-d - window that moves across the base area. +window that moves across the base area. The `window_strides` argument specifies the stride of the convolutional window in the spatial dimensions. For example, if the stride in the first spatial @@ -628,9 +634,22 @@ input feature dimension, and the filter would be reshaped from `[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more details, see `tf.nn.depthwise_conv2d`. +The `batch_group_count` (default value 1) argument can be used for depthwise +filters during backpropagation. `batch_group_count` needs to be a divisor of the +size of the `lhs` (input) batch dimension. If `batch_group_count` is greater +than 1, it means that the output batch dimension should be of size +`batch_group_size` where `batch_group_size = input batch / batch_group_count`. +For convolutions with `batch_group_count` greater than 1, the input batch size +must evenly divide into batch_group_size and output feature size, which implies +that the output feature size must be equal to batch_group_count. Conceptually, +this can be achieved by performing the usual convolution, and then scraping +`batch_group_size` number of elements on the diagonal of the matrix formed by +output batch and output feature. 
+ The output shape has these dimensions, in this order: -* `batch`: Same size as `batch` on the input (`lhs`). +* `batch`: The size of this dimension times `batch_group_count` should equal + the size of the `batch` dimension in lhs. * `z`: Same size as `output-z` on the kernel (`rhs`). * `spatial_dims`: One value for each valid placement of the convolutional window. @@ -658,15 +677,15 @@ Here is pseudo-code for a 2d convolution with padding and striding: ``` for (b, oz, oy, ox) { // output coordinates - value = 0; - for (iz, ky, kx) { // kernel coordinates and input z - iy = oy*stride_y + ky - pad_low_y; - ix = ox*stride_x + kx - pad_low_x; - if ((iy, ix) inside the base area considered without padding) { - value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); - } - } - output(b, oz, oy, ox) = value; +value = 0; +for (iz, ky, kx) { // kernel coordinates and input z +iy = oy*stride_y + ky - pad_low_y; +ix = ox*stride_x + kx - pad_low_x; +if ((iy, ix) inside the base area considered without padding) { +value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); +} +} +output(b, oz, oy, ox) = value; } ``` @@ -777,19 +796,19 @@ Here is an example of an implementation of `myfunc`: ``` extern "C" void myfunc(void* out, void** in) { - float (&x)[2] = *static_cast(in[0]); - float (&y)[2][3] = *static_cast(in[1]); - EXPECT_EQ(1, x[0]); - EXPECT_EQ(2, x[1]); - EXPECT_EQ(10, y[0][0]); - EXPECT_EQ(20, y[0][1]); - EXPECT_EQ(30, y[0][2]); - EXPECT_EQ(40, y[1][0]); - EXPECT_EQ(50, y[1][1]); - EXPECT_EQ(60, y[1][2]); - float (&z)[3][3] = *static_cast(out); - z[0][0] = x[1] + y[1][0]; - // ... +float (&x)[2] = *static_cast(in[0]); +float (&y)[2][3] = *static_cast(in[1]); +EXPECT_EQ(1, x[0]); +EXPECT_EQ(2, x[1]); +EXPECT_EQ(10, y[0][0]); +EXPECT_EQ(20, y[0][1]); +EXPECT_EQ(30, y[0][2]); +EXPECT_EQ(40, y[1][0]); +EXPECT_EQ(50, y[1][1]); +EXPECT_EQ(60, y[1][2]); +float (&z)[3][3] = *static_cast(out); +z[0][0] = x[1] + y[1][0]; +// ... 
} ``` @@ -856,44 +875,40 @@ DotGeneral performs the sum of products over contracting dimensions specified in 'dimension_numbers'. Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need -to be the same, but must be listed in the same order in both -'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes. -There must be exactly one contracting dimension on both 'lhs' and 'rhs'. +to be the same but must have the same dimension sizes. Example with contracting dimension numbers: ``` lhs = { {1.0, 2.0, 3.0}, - {4.0, 5.0, 6.0} } +{4.0, 5.0, 6.0} } rhs = { {1.0, 1.0, 1.0}, - {2.0, 2.0, 2.0} } +{2.0, 2.0, 2.0} } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(1); dnums.add_rhs_contracting_dimensions(1); DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, - {15.0, 30.0} } +{15.0, 30.0} } ``` -Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same -dimension number, must be listed in the same order in both arrays, must -have the same dimension sizes, and must be ordered before contracting and -non-contracting/non-batch dimension numbers. +Associated batch dimension numbers from the 'lhs' and 'rhs' must +have the same dimension sizes. Example with batch dimension numbers (batch size 2, 2x2 matrices): ``` lhs = { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } rhs = { { {1.0, 0.0}, - {0.0, 1.0} }, - { {1.0, 0.0}, - {0.0, 1.0} } } +{0.0, 1.0} }, +{ {1.0, 0.0}, +{0.0, 1.0} } } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -902,9 +917,9 @@ dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } ``` | Input | Output | Semantics | @@ -929,21 +944,21 @@ dimension: [start, start + size).
The shape of `start_indices` must be rank == `DynamicSlice(operand, start_indices, size_indices)` -| Arguments | Type | Semantics | -| --------------- | ------------------- | ----------------------------------- | -| `operand` | `XlaOp` | N dimensional array of type T | -| `start_indices` | `XlaOp` | Rank 1 array of N integers | -: : : containing the starting indices of : -: : : the slice for each dimension. Value : -: : : must be greater than or equal to : -: : : zero. : -| `size_indices` | `ArraySlice` | List of N integers containing the | -: : : slice size for each dimension. Each : -: : : value must be strictly greater than : -: : : zero, and start + size must be less : -: : : than or equal to the size of the : -: : : dimension to avoid wrapping modulo : -: : : dimension size. : +| Arguments | Type | Semantics | +| --------------- | --------------------- | ---------------------------------- | +| `operand` | `XlaOp` | N dimensional array of type T | +| `start_indices` | sequence of N `XlaOp` | List of N scalar integers | +: : : containing the starting indices of : +: : : the slice for each dimension. : +: : : Value must be greater than or : +: : : equal to zero. : +| `size_indices` | `ArraySlice` | List of N integers containing the | +: : : slice size for each dimension. : +: : : Each value must be strictly : +: : : greater than zero, and start + : +: : : size must be less than or equal to : +: : : the size of the dimension to avoid : +: : : wrapping modulo dimension size. 
: The effective slice indices are computed by applying the following transformation for each index `i` in `[1, N)` before performing the slice: @@ -963,22 +978,22 @@ let a = {0.0, 1.0, 2.0, 3.0, 4.0} let s = {2} DynamicSlice(a, s, {2}) produces: - {2.0, 3.0} +{2.0, 3.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let s = {2, 1} DynamicSlice(b, s, {2, 2}) produces: - { { 7.0, 8.0}, - {10.0, 11.0} } +{ { 7.0, 8.0}, +{10.0, 11.0} } ``` ## DynamicUpdateSlice @@ -994,19 +1009,22 @@ the rank of `operand`. `DynamicUpdateSlice(operand, update, start_indices)` -| Arguments | Type | Semantics | -| --------------- | ------- | ------------------------------------------------ | -| `operand` | `XlaOp` | N dimensional array of type T | -| `update` | `XlaOp` | N dimensional array of type T containing the | -: : : slice update. Each dimension of update shape : -: : : must be strictly greater than zero, and start + : -: : : update must be less than or equal to the operand : -: : : size for each dimension to avoid generating : -: : : out-of-bounds update indices. : -| `start_indices` | `XlaOp` | Rank 1 array of N integers containing the | -: : : starting indices of the slice for each : -: : : dimension. Value must be greater than or equal : -: : : to zero. : +| Arguments | Type | Semantics | +| --------------- | --------------------- | ---------------------------------- | +| `operand` | `XlaOp` | N dimensional array of type T | +| `update` | `XlaOp` | N dimensional array of type T | +: : : containing the slice update. Each : +: : : dimension of update shape must be : +: : : strictly greater than zero, and : +: : : start + update must be less than : +: : : or equal to the operand size for : +: : : each dimension to avoid generating : +: : : out-of-bounds update indices. 
: +| `start_indices` | sequence of N `XlaOp` | List of N scalar integers | +: : : containing the starting indices of : +: : : the slice for each dimension. : +: : : Value must be greater than or : +: : : equal to zero. : The effective slice indices are computed by applying the following transformation for each index `i` in `[1, N)` before performing the slice: @@ -1027,29 +1045,29 @@ let u = {5.0, 6.0} let s = {2} DynamicUpdateSlice(a, u, s) produces: - {0.0, 1.0, 5.0, 6.0, 4.0} +{0.0, 1.0, 5.0, 6.0, 4.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let u = - { {12.0, 13.0}, - {14.0, 15.0}, - {16.0, 17.0} } +{ {12.0, 13.0}, +{14.0, 15.0}, +{16.0, 17.0} } let s = {1, 1} DynamicUpdateSlice(b, u, s) produces: - { {0.0, 1.0, 2.0}, - {3.0, 12.0, 13.0}, - {6.0, 14.0, 15.0}, - {9.0, 16.0, 17.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 12.0, 13.0}, +{6.0, 14.0, 15.0}, +{9.0, 16.0, 17.0} } ``` ## Element-wise binary arithmetic operations @@ -1080,7 +1098,7 @@ When `Op` is `Rem`, the sign of the result is taken from the dividend, and the absolute value of the result is always less than the divisor's absolute value. Integer division overflow (signed/unsigned division/remainder by zero or signed -divison/remainder of `INT_SMIN` with `-1`) produces an implementation defined +division/remainder of `INT_SMIN` with `-1`) produces an implementation defined value. An alternative variant with different-rank broadcasting support exists for these @@ -1168,7 +1186,7 @@ if and only if the corresponding input element is finite. `Sign(operand)` Element-wise sign operation `x -> sgn(x)` where -$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$ +$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$ using the comparison operator of the element type of `operand`. 
@@ -1235,42 +1253,42 @@ shape of `start_indices` to be `[6,7,1]`). The bounds for the output array along dimension `i` is computed as follows: - 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for - some `k`) then we pick the corresponding dimension bounds out of - `start_indices.shape`, skipping `index_vector_dim` (i.e. pick - `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and - `start_indices.shape.dims`[`k`+`1`] otherwise). +1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for +some `k`) then we pick the corresponding dimension bounds out of +`start_indices.shape`, skipping `index_vector_dim` (i.e. pick +`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and +`start_indices.shape.dims`[`k`+`1`] otherwise). - 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for - some `k`) then we pick the corresponding bound out of `slice_sizes` after - accounting for `collapsed_slice_dims` (i.e. we pick - `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` - with the bounds at indices `collapsed_slice_dims` removed). +2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for +some `k`) then we pick the corresponding bound out of `slice_sizes` after +accounting for `collapsed_slice_dims` (i.e. we pick +`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` +with the bounds at indices `collapsed_slice_dims` removed). Formally, the operand index `In` corresponding to an output index `Out` is computed as follows: - 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out - vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where - Combine(A, b) inserts b at position `index_vector_dim` into A. Note that - this is well defined even if `G` is empty -- if `G` is empty then `S` = - `start_indices`. - - 2. Create a starting index, `S``in`, into `operand` using `S` by - scattering `S` using `start_index_map`. 
More precisely: - 1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < - `start_index_map.size`. - 2. `S``in`[`_`] = `0` otherwise. - - 3. Create an index `O``in` into `operand` by scattering the indices - at the offset dimensions in `Out` according to the `collapsed_slice_dims` - set. More precisely: - 1. `O``in`[`expand_offset_dims`(`k`)] = - `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` - (`expand_offset_dims` is defined below). - 2. `O``in`[`_`] = `0` otherwise. - 4. `In` is `O``in` + `S``in` where + is element-wise - addition. +1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out +vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where +Combine(A, b) inserts b at position `index_vector_dim` into A. Note that +this is well defined even if `G` is empty -- if `G` is empty then `S` = +`start_indices`. + +2. Create a starting index, `S``in`, into `operand` using `S` by +scattering `S` using `start_index_map`. More precisely: +1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < +`start_index_map.size`. +2. `S``in`[`_`] = `0` otherwise. + +3. Create an index `O``in` into `operand` by scattering the indices +at the offset dimensions in `Out` according to the `collapsed_slice_dims` +set. More precisely: +1. `O``in`[`expand_offset_dims`(`k`)] = +`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` +(`expand_offset_dims` is defined below). +2. `O``in`[`_`] = `0` otherwise. +4. `In` is `O``in` + `S``in` where + is element-wise +addition. `expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., @@ -1282,21 +1300,21 @@ and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., Informally, every index `Out` in the output array corresponds to an element `E` in the operand array, computed as follows: - - We use the batch dimensions in `Out` to look up a starting index from - `start_indices`. 
+- We use the batch dimensions in `Out` to look up a starting index from +`start_indices`. - - We use `start_index_map` to map the starting index (which may have size less - than operand.rank) to a "full" starting index into operand. +- We use `start_index_map` to map the starting index (which may have size less +than operand.rank) to a "full" starting index into operand. - - We dynamic-slice out a slice with size `slice_sizes` using the full starting - index. +- We dynamic-slice out a slice with size `slice_sizes` using the full starting +index. - - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. - Since all collapsed slice dimensions have to have bound 1 this reshape is - always legal. +- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. +Since all collapsed slice dimensions have to have bound 1 this reshape is +always legal. - - We use the offset dimensions in `Out` to index into this slice to get the - input element, `E`, corresponding to output index `Out`. +- We use the offset dimensions in `Out` to index into this slice to get the +input element, `E`, corresponding to output index `Out`. `index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples that follow. More interesting values for `index_vector_dim` does not @@ -1315,7 +1333,7 @@ the output shape, and maps it to an element in the input array in the following way:
- +
We first select an (`X`,`Y`) vector from the gather indices array using `G`. @@ -1334,7 +1352,7 @@ version of the example above using a "gather indices" array of shape `[4,5,2]` would translate indices like this:
- +
Again, this acts as a batch dynamic slice `G``0` and @@ -1343,27 +1361,27 @@ Again, this acts as a batch dynamic slice `G``0` and The gather operation in XLA generalizes the informal semantics outlined above in the following ways: - 1. We can configure which dimensions in the output shape are the offset - dimensions (dimensions containing `O``0`, `O``1` in - the last example). The output batch dimensions (dimensions containing - `G``0`, `G``1` in the last example) are defined to be - the output dimensions that are not offset dimensions. +1. We can configure which dimensions in the output shape are the offset +dimensions (dimensions containing `O``0`, `O``1` in +the last example). The output batch dimensions (dimensions containing +`G``0`, `G``1` in the last example) are defined to be +the output dimensions that are not offset dimensions. - 2. The number of output offset dimensions explicitly present in the output - shape may be smaller than the input rank. These "missing" dimensions, which - are listed explicitly as `collapsed_slice_dims`, must have a slice size of - `1`. Since they have a slice size of `1` the only valid index for them is - `0` and eliding them does not introduce ambiguity. +2. The number of output offset dimensions explicitly present in the output +shape may be smaller than the input rank. These "missing" dimensions, which +are listed explicitly as `collapsed_slice_dims`, must have a slice size of +`1`. Since they have a slice size of `1` the only valid index for them is +`0` and eliding them does not introduce ambiguity. - 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last - example) may have fewer elements than the input array rank, and an explicit - mapping dictates how the index should be expanded to have the same rank as - the input. +3. 
The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last +example) may have fewer elements than the input array rank, and an explicit +mapping dictates how the index should be expanded to have the same rank as +the input. As a final example, we use (2) and (3) to implement `tf.gather_nd`:
- +
`G``0` and `G``1` are used to slice out a starting index @@ -1442,11 +1460,11 @@ dependency between the while loops. ``` result1 = while (condition, init = init_value) { - Infeed(shape) +Infeed(shape) } result2 = while (condition, init = result1) { - Infeed(shape) +Infeed(shape) } ``` @@ -1464,7 +1482,9 @@ Infeed of the device. Builds a constant literal on device rather than a potentially large host transfer. Creates a rank 1 array of values starting at zero and incrementing by -one. +one. For floating-point types, the produced array is equivalent to +`ConvertElementType(Iota(...))` where the `Iota` is of integral type and the +conversion is to the floating-point type. Arguments | Type | Semantics ---------------- | --------------- | ------------------------------------ @@ -1853,6 +1873,20 @@ non-deterministic. Therefore, the reduction function should not be overly sensitive to reassociation. See the discussion about associativity in the context of [`Reduce`](#reduce) for more details. +## ReplicaId + +See also +[`XlaBuilder::ReplicaId`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Returns the unique ID (U32 scalar) of the replica. + + `ReplicaId()` + +The unique ID of each replica is an unsigned integer in the interval `[0, N)`, +where `N` is the number of replicas. Since all the replicas are running the same +program, a `ReplicaId()` call in the program will return a different value on +each replica. 
+ ## Reshape See also diff --git a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md b/tensorflow/compiler/xla/g3doc/tiled_layout.md similarity index 96% rename from tensorflow/compiler/xla/g3doc/layout_with_tiling.md rename to tensorflow/compiler/xla/g3doc/tiled_layout.md index 5e990851af7495ebd4417e44f1d955fcc14dadf1..21e88ceab6208cdf940826d769fd93713044d5a0 100644 --- a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md +++ b/tensorflow/compiler/xla/g3doc/tiled_layout.md @@ -1,9 +1,7 @@ # Tiled layout -*Note: This doc describes how tiled layout is intended to work. Tiling is being -implemented, but this is an early effort and it is currently not even guaranteed -to get an Unimplemented error if one tries to use tiling - it may be just -silently ignored.* +Caution: Tiled layout is *pre-release* and this describes how it's intended to +work. Errors may be silently ignored.
![](images/xla_array_layout_figure1.png) diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 2a0241af3ef359c4d1c6c1ab9319b5b293110f7a..eebd8245abe759b71b3fe732943761325ea04b81 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -141,7 +140,7 @@ namespace xla { /* static */ bool IndexUtil::IndexInBounds(const Shape& shape, absl::Span index) { - int64 rank = ShapeUtil::Rank(shape); + int64 rank = shape.rank(); if (rank != index.size()) { return false; } diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..000c4fdc40519214fa9fa721a8987b77b534442b --- /dev/null +++ b/tensorflow/compiler/xla/layout.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { + +TileProto Tile::ToProto() const { + TileProto tile_proto; + for (int64 i : dimensions()) { + tile_proto.add_dimensions(i); + } + return tile_proto; +} + +string Tile::ToString() const { + std::vector elements; + for (auto dim : dimensions()) { + if (dim >= 0) { + elements.push_back(std::to_string(dim)); + } else { + if (dim == kCombineDimension) { + elements.push_back("*"); + } else { + elements.push_back(absl::StrCat("Invalid value ", dim)); + } + } + } + return absl::StrCat("(", absl::StrJoin(elements, ","), ")"); +} + +/* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { + Layout layout; + layout.set_format(proto.format()); + layout.minor_to_major_.reserve(proto.minor_to_major_size()); + for (const int64 dimension : proto.minor_to_major()) { + layout.add_minor_to_major(dimension); + } + layout.set_max_sparse_elements(proto.max_sparse_elements()); + for (const TileProto& tile_proto : proto.tiles()) { + *layout.add_tiles() = Tile::CreateFromProto(tile_proto); + } + layout.set_element_size_in_bits(proto.element_size_in_bits()); + return layout; +} + +LayoutProto Layout::ToProto() const { + LayoutProto proto; + proto.set_format(format_); + proto.mutable_minor_to_major()->Reserve(minor_to_major_size()); + for (const int64 dimension : minor_to_major()) { + proto.add_minor_to_major(dimension); + } + proto.set_max_sparse_elements(max_sparse_elements_); + for (const Tile& tile : tiles()) { + *proto.add_tiles() = tile.ToProto(); + } + proto.set_element_size_in_bits(element_size_in_bits()); + return proto; +} + +string Layout::ToString() const { + if (format() == SPARSE) { + CHECK_EQ(tiles_size(), 0) << "Sparse layout should not be tiled."; + return absl::StrCat("sparse{", 
max_sparse_elements(), "}"); + } else if (format() == DENSE) { + string colon_string = tiles().empty() ? "" : "T"; + for (Tile tile : tiles()) { + absl::StrAppend(&colon_string, tile.ToString()); + } + if (element_size_in_bits() != 0) { + absl::StrAppend(&colon_string, "E(", element_size_in_bits(), ")"); + } + return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), + colon_string.empty() ? "" : ":", colon_string, "}"); + } else { + CHECK_EQ(format(), INVALID_FORMAT); + return "invalid{}"; + } +} + +bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { + if (lhs.format() != rhs.format() || + lhs.minor_to_major() != rhs.minor_to_major() || + lhs.max_sparse_elements() != rhs.max_sparse_elements()) { + return false; + } + if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) { + return false; + } + if (!ignore_element_size_ && + lhs.element_size_in_bits() != rhs.element_size_in_bits()) { + return false; + } + return true; +} + +bool Layout::operator==(const Layout& other) const { + return Equal()(*this, other); +} + +std::ostream& operator<<(std::ostream& out, const Tile& tile) { + out << tile.ToString(); + return out; +} + +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << layout.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..acc449b781b503142b24ed7229e3559230bb1599 --- /dev/null +++ b/tensorflow/compiler/xla/layout.h @@ -0,0 +1,234 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ +#define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ + +#include + +#include "absl/types/span.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for +// details. +class Tile { + public: + Tile() = default; + explicit Tile(absl::Span dimensions) + : dimensions_(dimensions.begin(), dimensions.end()) {} + + // De/Serialize a Tile to and from a TileProto. + static Tile CreateFromProto(const TileProto& tile_proto) { + return Tile(AsInt64Slice(tile_proto.dimensions())); + } + TileProto ToProto() const; + + bool operator==(const Tile& other) const { + return dimensions() == other.dimensions(); + } + bool operator!=(const Tile& other) const { return !(*this == other); } + + string ToString() const; + + // Returns the bound of the tile in the given dimension index. + int64 dimension(int i) const { return dimensions_.at(i); } + + // Returns the dimensions of the tile. 
+ const std::vector& dimensions() const { return dimensions_; } + + Tile& add_dimensions(int64 value) { + dimensions_.push_back(value); + return *this; + } + + Tile& clear_dimensions() { + dimensions_.clear(); + return *this; + } + + // This dimension size means the corresponding dimension in the shape is + // combined with the next minor dimension before tiling is applied. + static constexpr int64 kCombineDimension = std::numeric_limits::min(); + + private: + // The bounds of the tile. + std::vector dimensions_; +}; + +class Layout { + public: + Layout() = default; + + // Constructs a dense layout with the given minor-to-major order. + explicit Layout(absl::Span minor_to_major) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} + + // Constructs a dense tiled layout with the given minor-to-major order and + // tiles. + Layout(absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits = 0) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()), + tiles_(tiles.begin(), tiles.end()), + element_size_in_bits_(element_size_in_bits) {} + + // Constructs a Layout from a LayoutProto. + static Layout CreateFromProto(const LayoutProto& proto); + + // Returns a LayoutProto representation of the Layout. + LayoutProto ToProto() const; + + // Returns a human-readable string that represents this layout. + string ToString() const; + + // Equal is a configurable functor to check the equality of two layouts. 
+ // + // Examples: + // + // - Comparing two layouts ignoring their difference in tiles: + // Equal().IgnoreTiles()(layout1, layout2); + // + // - Comparing two layouts ignoring their difference in tiles and element + // size: + // Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2); + class Equal { + public: + Equal() = default; + + bool operator()(const Layout& lhs, const Layout& rhs); + + Equal& IgnoreTiles() { + ignore_tiles_ = true; + return *this; + } + + Equal& IgnoreElementSize() { + ignore_element_size_ = true; + return *this; + } + + private: + bool ignore_tiles_ = false; + bool ignore_element_size_ = false; + }; + + bool operator==(const Layout& other) const; + bool operator!=(const Layout& other) const { return !(*this == other); } + + // The following methods mirror the protobuf generated code interface for the + // message LayoutProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the format. + Format format() const { return format_; } + Layout& set_format(Format value) { + format_ = value; + return *this; + } + + // Methods for accessing the minor-to-major array. + int minor_to_major_size() const { return minor_to_major_.size(); } + int64 minor_to_major(int index) const { return minor_to_major_.at(index); } + Layout& set_minor_to_major(int index, int64 value) { + minor_to_major_.at(index) = value; + return *this; + } + Layout& add_minor_to_major(int64 value) { + minor_to_major_.push_back(value); + return *this; + } + Layout& clear_minor_to_major() { + minor_to_major_.clear(); + return *this; + } + const std::vector& minor_to_major() const { return minor_to_major_; } + std::vector* mutable_minor_to_major() { return &minor_to_major_; } + + // Methods for accessing the tile field. 
+ int tiles_size() const { return tiles_.size(); } + const Tile& tiles(int index) const { return tiles_.at(index); } + Tile* mutable_tiles(int index) { return &tiles_.at(index); } + Tile* add_tiles() { + tiles_.push_back(Tile()); + return &tiles_.back(); + } + Layout& clear_tiles() { + tiles_.clear(); + return *this; + } + const std::vector& tiles() const { return tiles_; } + std::vector* mutable_tiles() { return &tiles_; } + + // Methods for accessing the int64 fields. + int64 max_sparse_elements() const { return max_sparse_elements_; } + Layout& set_max_sparse_elements(int64 value) { + max_sparse_elements_ = value; + return *this; + } + int64 element_size_in_bits() const { return element_size_in_bits_; } + Layout& set_element_size_in_bits(int64 value) { + element_size_in_bits_ = value; + return *this; + } + + void Swap(Layout* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + format_ = INVALID_FORMAT; + minor_to_major_.clear(); + max_sparse_elements_ = 0; + element_size_in_bits_ = 0; + } + + private: + // The format of this layout. + Format format_ = INVALID_FORMAT; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). + std::vector minor_to_major_; + + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be zero unless the format is SPARSE. + int64 max_sparse_elements_ = 0; + + // The tiles used in tiling-based layout. + std::vector tiles_; + + // The number of bits used to store an individual array element. 
+ int64 element_size_in_bits_ = 0; +}; + +std::ostream& operator<<(std::ostream& out, const Tile& Tile); +std::ostream& operator<<(std::ostream& out, const Layout& layout); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5d71c553ed2e0cfd5d5945144dd476557582b5f --- /dev/null +++ b/tensorflow/compiler/xla/layout_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class LayoutTest : public ::testing::Test {}; + +TEST_F(LayoutTest, ToString) { + EXPECT_EQ(Layout().ToString(), "invalid{}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(), + "sparse{123}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), + "{3,2,1,0:T(42,123)(4,5)}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0:T(2,55)E(42)}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({-2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0:T(Invalid value -2,55)E(42)}"); +} + +TEST_F(LayoutTest, StreamOut) { + { + std::ostringstream oss; + oss << Tile({7, 8}); + EXPECT_EQ(oss.str(), "(7,8)"); + } + + { + std::ostringstream oss; + oss << Layout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); + } +} + +TEST_F(LayoutTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + +TEST_F(LayoutTest, Equality) { + EXPECT_EQ(Layout(), Layout()); + const std::vector empty_dims; + EXPECT_EQ(Layout(empty_dims), Layout(empty_dims)); + EXPECT_NE(Layout(), Layout(empty_dims)); + EXPECT_EQ(Layout({0, 1, 2, 3}), Layout({0, 1, 2, 3})); + EXPECT_NE(Layout({0, 1, 2, 3}), Layout({0, 1, 2})); + 
EXPECT_EQ(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 44})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 45})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2, 3})); + EXPECT_EQ(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(33)); + EXPECT_NE(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(7)); + EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE)); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(42)); + EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(24)); + + EXPECT_FALSE( + Layout::Equal()(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2}))); + EXPECT_TRUE(Layout::Equal().IgnoreTiles()(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}))); + EXPECT_FALSE( + Layout::Equal()(Layout({0, 1, 2}, {}, 32), Layout({0, 1, 2}, {}, 1))); + EXPECT_TRUE(Layout::Equal().IgnoreElementSize()(Layout({0, 1, 2}, {}, 32), + Layout({0, 1, 2}, {}, 1))); +} + +TEST_F(LayoutTest, LayoutToFromProto) { + // Round-trips a Layout through proto de/serialization. 
+ auto expect_unchanged = [](const Layout& layout) { + EXPECT_EQ(layout, Layout::CreateFromProto(layout.ToProto())); + }; + + expect_unchanged(Layout()); + expect_unchanged(Layout({1, 3, 2, 0})); + expect_unchanged(Layout().set_format(SPARSE)); + expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123)); + expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42)); + expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index dbb81381acde645f08639737b6e7b6f6ad971f9b..62314118ca9713a04cb4e3cf6ad261b966d85f15 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -41,27 +41,37 @@ namespace { // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets // minor_to_major to the value that represents the default layout. -void SetDefaultLayoutToContainer( - tensorflow::protobuf::RepeatedField* - minor_to_major) { +void SetDefaultLayoutToContainer(std::vector* minor_to_major) { // The default XLA layout is major-to-minor (dim 0 is major). // For more information on XLA layouts, see: // https://www.tensorflow.org/performance/xla/shapes const int64 size = minor_to_major->size(); for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, size - 1 - i); + (*minor_to_major)[i] = size - 1 - i; } } } // namespace /* static */ Layout LayoutUtil::MakeLayout( - absl::Span minor_to_major) { + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { Layout layout; layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { layout.add_minor_to_major(dimension_number); } + for (Tile tile : tiles) { + for (int64 dim : tile.dimensions()) { + if (dim < 0 && dim != Tile::kCombineDimension) { + LOG(FATAL) << "Tile dimension size needs to be mininum int64 value if " + "it's negative. 
Value is " + << dim; + } + } + *layout.add_tiles() = tile; + } + layout.set_element_size_in_bits(element_size_in_bits); return layout; } @@ -94,9 +104,8 @@ namespace { Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; layout.set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = layout.mutable_minor_to_major(); - minor_to_major->Resize(rank, 0); + std::vector* minor_to_major = layout.mutable_minor_to_major(); + minor_to_major->resize(rank, 0); SetDefaultLayoutToContainer(minor_to_major); return layout; } @@ -104,13 +113,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { - if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + if (shape.IsOpaque() || shape.IsToken()) { // Opaque and token types have empty layouts. return Layout(); } // A Layout proto corresponds to a single array, not a tuple. - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -131,17 +140,16 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) { - if (ShapeUtil::IsTuple(*shape)) { + if (shape->IsTuple()) { // Tuple shape. for (auto& element_shape : *shape->mutable_tuple_shapes()) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsArray(*shape)) { + } else if (shape->IsArray()) { shape->mutable_layout()->set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major->Resize(shape->dimensions_size(), 0); + auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); + minor_to_major->resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); } else { // Opaque, token types etc. have no layout. 
@@ -164,7 +172,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ Status LayoutUtil::ValidateLayoutInShape( const Shape& shape, bool allow_missing_layouts) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { // Tuple shape. if (shape.has_layout()) { return InvalidArgument("tuple should not have a layout field"); @@ -174,7 +182,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { ValidateLayoutInShape(element_shape, allow_missing_layouts)); } return Status::OK(); - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { if (!shape.has_layout()) { if (allow_missing_layouts) { return Status::OK(); @@ -196,11 +204,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { if (layout.minor_to_major_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a non-trivial layout", @@ -210,25 +218,24 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { - return InvalidArgument( - "Layout has an invalid format (%d) in layout {%s}, shape {%s}", - layout.format(), layout.ShortDebugString(), shape.ShortDebugString()); + return InvalidArgument("Layout has an invalid format (%d)", + layout.format()); } if (layout.format() == DENSE) { - if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { + if (layout.minor_to_major_size() != shape.rank()) { return InvalidArgument( "layout minor_to_major field contains %d elements, " "but shape is rank %d: {%s}; shape: %s", - layout.minor_to_major_size(), ShapeUtil::Rank(shape), + layout.minor_to_major_size(), shape.rank(), absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString()); } - std::vector 
dimensions_in_layout(ShapeUtil::Rank(shape), false); - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + std::vector dimensions_in_layout(shape.rank(), false); + for (int64 i = 0; i < shape.rank(); ++i) { int64 dim = layout.minor_to_major(i); - if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + if (dim < 0 || dim >= shape.rank()) { return InvalidArgument( "layout minor_to_major field has out-of-bounds value: %s", HumanString(layout)); @@ -240,6 +247,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } dimensions_in_layout[dim] = true; } + } else { + if (layout.tiles_size() != 0) { + return InvalidArgument("Only dense layouts can be tiled."); + } } return Status::OK(); @@ -260,8 +271,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && shape.has_layout() && - IsDense(shape.layout()); + return shape.IsArray() && shape.has_layout() && IsDense(shape.layout()); } /* static */ bool LayoutUtil::IsDense(const Layout& layout) { @@ -281,8 +291,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && shape.has_layout() && - IsSparse(shape.layout()); + return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout()); } /* static */ bool LayoutUtil::IsSparse(const Layout& layout) { @@ -295,11 +304,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { // Tuple shape: all subshapes must have a layout. - return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), - [](const Shape& s) { return HasLayout(s); }); - } else if (!ShapeUtil::IsArray(shape)) { + return absl::c_all_of(shape.tuple_shapes(), + [](const Shape& s) { return HasLayout(s); }); + } else if (!shape.IsArray()) { // Opaque, token types etc. ignore layout. 
return true; } @@ -316,7 +325,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) { - return protobuf_util::ProtobufEquals(lhs, rhs); + return lhs == rhs; } /* static */ absl::Span LayoutUtil::MinorToMajor( @@ -358,22 +367,18 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { - if (IsSparse(layout)) { - return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); - } - CHECK(IsDense(layout)); - return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); + return layout.ToString(); } namespace { // Internal helper for recursively copying layouts. Status CopyLayoutInternal(const Shape& src, Shape* dst) { - if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { + if (src.IsTuple() != dst->IsTuple()) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); } - if (ShapeUtil::IsTuple(src)) { + if (src.IsTuple()) { if (ShapeUtil::TupleElementCount(src) != ShapeUtil::TupleElementCount(*dst)) { return InvalidArgument( @@ -385,7 +390,7 @@ Status CopyLayoutInternal(const Shape& src, Shape* dst) { } } else { if (src.has_layout()) { - if (ShapeUtil::Rank(src) != ShapeUtil::Rank(*dst)) { + if (src.rank() != dst->rank()) { return InvalidArgument("cannot copy layout from shape: ranks differs"); } TF_RETURN_IF_ERROR( @@ -407,9 +412,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs)) { - if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (lhs.IsTuple()) { + if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -418,8 +423,8 @@ Status 
LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } } return true; - } else if (ShapeUtil::IsArray(lhs)) { - return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && + } else if (lhs.IsArray()) { + return lhs.rank() == rhs.rank() && LayoutUtil::Equal(lhs.layout(), rhs.layout()); } else { // Layouts of non-array and non-tuple shapes is ignored. @@ -435,7 +440,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { positions_in_layout.push_back( PositionInContainer(layout.minor_to_major(), dim)); } - std::sort(positions_in_layout.begin(), positions_in_layout.end()); + absl::c_sort(positions_in_layout); for (size_t i = 1; i < positions_in_layout.size(); ++i) { if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) { return false; @@ -444,11 +449,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return true; } -std::ostream& operator<<(std::ostream& out, const Layout& layout) { - out << LayoutUtil::HumanString(layout); - return out; -} - /*static*/ size_t LayoutUtil::Hash(const Layout& layout) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c298e57252449ce3f1f9055436e918f2d9f17f1..9997aef465daa48ee77050e03d97cde0ea2425cc 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -35,7 +36,9 @@ class LayoutUtil { public: // Creates a layout with the given minor-to-major dimension order. (This is a // convenience function for protobuf construction.) 
- static Layout MakeLayout(absl::Span minor_to_major); + static Layout MakeLayout(absl::Span minor_to_major, + absl::Span tiles = {}, + int64 element_size_in_bits = 0); // Similar to MakeLayout, but take indices in reverse order. static Layout MakeLayoutFromMajorToMinor( @@ -195,8 +198,6 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; -std::ostream& operator<<(std::ostream& out, const Layout& layout); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 12ce2d2d7c6fa8c590035f9ff2af50001ccf80d8..12da214063676717aa075e66aa54974f4cc2b31b 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -317,15 +317,79 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } -TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { - EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), - 101); -} - -TEST_F(LayoutUtilTest, StreamOut) { - std::ostringstream oss; - oss << LayoutUtil::MakeLayout({0, 1, 2}); - EXPECT_EQ(oss.str(), "{0,1,2}"); +TEST_F(LayoutUtilTest, HumanStringWithTiling) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {0, 1, 2}); + Tile* tile; + + // No tiling. + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), "f32[2,3,4]{0,1,2}"); + + // 2D tile. + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(512); + tile->add_dimensions(1024); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "f32[2,3,4]{0,1,2:T(512,1024)}"); + + // 1D tile. + shape.mutable_layout()->clear_tiles(); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(512); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "f32[2,3,4]{0,1,2:T(512)}"); + + // 2 tiles. 
+ shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 4}, {1, 2, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(16); + tile->add_dimensions(256); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(1); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[2,3,4]{1,2,0:T(16,256)(2,1)}"); + + // PRED with element size of 8 bits. + shape = ShapeUtil::MakeShapeWithLayout(PRED, {8, 8, 8}, {0, 2, 1}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(8); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:T(8,128)}"); + + // PRED with element size of 32 bits. + shape.mutable_layout()->clear_tiles(); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(8); + tile->add_dimensions(128); + shape.mutable_layout()->set_element_size_in_bits(32); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:T(8,128)E(32)}"); + + // No tile. PRED with element size of 32 bits. + shape.mutable_layout()->clear_tiles(); + shape.mutable_layout()->set_element_size_in_bits(32); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:E(32)}"); + + // Tile with negative dimension size for combining dimensions. + shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 1004}, {2, 1, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[2,3,1004]{2,1,0:T(2,*,128)}"); + + // Tile with two negative dimensions. 
+ shape = ShapeUtil::MakeShapeWithLayout(BF16, {8, 2, 3, 1004}, {3, 2, 1, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[8,2,3,1004]{3,2,1,0:T(2,*,*,128)}"); } TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 8f480c1f1079b4e1a5be53958ebdf6e004ad9ebe..5cd738d0f7769ceac7eb3bdbc5abd3196d9cf99c 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -29,10 +29,12 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -42,7 +44,6 @@ namespace xla { namespace { using absl::StrCat; -using absl::StrFormat; constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; @@ -107,7 +108,7 @@ Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); @@ -118,7 +119,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->emplace_back(std::move(child_piece)); } - } else if (ShapeUtil::IsArray(shape)) { + } 
else if (shape.IsArray()) { if (allocate_arrays) { if (LayoutUtil::IsSparseArray(shape)) { // For sparse arrays, the buffer must be of the size of the maximum @@ -129,7 +130,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); piece->set_sparse_indices( - new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); + new SparseIndexArray(max_sparse_elements, shape.rank())); } else { piece->set_buffer(new char[piece->size_bytes()]); } @@ -187,7 +188,7 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) { Literal literal(shape); literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { - if (ShapeUtil::IsArray(piece->subshape())) { + if (piece->subshape().IsArray()) { memset(piece->untyped_data(), 0, piece->size_bytes()); } }); @@ -208,16 +209,15 @@ template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, absl::Span dest_base, absl::Span copy_size) { - TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + TF_RET_CHECK(src_literal.shape().rank() == src_base.size()); + TF_RET_CHECK(shape().rank() == dest_base.size()); auto linear_index = [](const Shape& shape, absl::Span multi_index) { return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); }; - if (ShapeUtil::Rank(src_literal.shape()) == 0 || - ShapeUtil::Rank(shape()) == 0) { + if (src_literal.shape().rank() == 0 || shape().rank() == 0) { // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. 
TF_RET_CHECK(copy_size.empty()); @@ -312,7 +312,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, proto_element = &proto_element->tuple_literals(i); } - if (ShapeUtil::IsTuple(piece->subshape())) { + if (piece->subshape().IsTuple()) { if (proto_element->tuple_literals_size() != ShapeUtil::TupleElementCount(piece->subshape())) { return InvalidArgument( @@ -326,7 +326,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } - CHECK(ShapeUtil::IsArray(piece->subshape())); + CHECK(piece->subshape().IsArray()); TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); return Status::OK(); @@ -336,7 +336,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, } std::vector Literal::DecomposeTuple() { - CHECK(ShapeUtil::IsTuple(shape())); + CHECK(shape().IsTuple()); std::vector elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), @@ -375,7 +375,7 @@ void CopyElementsBetween(absl::Span dest, if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; } - std::vector index(ShapeUtil::Rank(dest_shape)); + std::vector index(dest_shape.rank()); do { dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; @@ -392,7 +392,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { memcpy(buffer(), src.buffer(), src.size_bytes()); } else { TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); - std::vector origin(ShapeUtil::Rank(subshape()), 0); + std::vector origin(subshape().rank(), 0); switch (subshape().element_type()) { #define COPY_ELEMENTS(XLA_T, NATIVE_T) \ case (XLA_T): \ @@ -412,6 +412,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { COPY_ELEMENTS(F32, float); COPY_ELEMENTS(F64, double); COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(C128, complex128); 
COPY_ELEMENTS(PRED, bool); #undef COPY_ELEMENTS default: @@ -438,7 +439,7 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, } return root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { - if (!ShapeUtil::IsArray(piece->subshape())) { + if (!piece->subshape().IsArray()) { return Status::OK(); } @@ -477,7 +478,7 @@ Status Literal::MoveFrom(Literal&& src_literal, src_literal.root_piece_->ForEachSubpiece( [&](const ShapeIndex& src_index, const Piece& src_piece) { - if (!ShapeUtil::IsArray(src_piece.subshape())) { + if (!src_piece.subshape().IsArray()) { return; } @@ -504,8 +505,8 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, absl::Span src_base, absl::Span dest_base, absl::Span copy_size) { - TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); - TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(src_literal.shape().IsArray()) << ShapeUtil::HumanString(src_literal.shape()); TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); @@ -549,6 +550,9 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, case C64: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); + case C128: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case PRED: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); @@ -562,8 +566,8 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, } void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(element_count(), values.bits()); CHECK_EQ(shape().element_type(), PRED); for (int64 i = 0; i < static_cast(values.bits()); ++i) { @@ -592,7 +596,7 @@ Literal 
LiteralBase::Relayout(const Shape& shape_with_layout) const { ShapeUtil::ForEachSubshape( result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { TF_CHECK_OK(result.CopyFrom(*this, /*dest_shape_index=*/index, /*src_shape_index=*/index)); @@ -603,7 +607,7 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { StatusOr LiteralBase::Broadcast( const Shape& result_shape, absl::Span dimensions) const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return InvalidArgument("Broadcast only supports arrays."); } @@ -643,13 +647,12 @@ StatusOr LiteralBase::Broadcast( StatusOr LiteralBase::Reshape( absl::Span dimensions) const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return InvalidArgument("Reshape does not support tuples."); } Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - output = - Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); + output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank())); } else { output = Clone(); } @@ -671,8 +674,8 @@ StatusOr LiteralBase::Reshape( } Literal LiteralBase::Transpose(absl::Span permutation) const { - CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + CHECK(shape().IsArray()) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, shape().rank())) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and // do a straight memory copy of the raw data set. 
@@ -711,10 +714,10 @@ template Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span start_indices) const { Literal result_literal(result_shape); - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + DimensionVector new_indices(result_shape.rank()); result_literal.EachCell( [&](absl::Span indices, NativeT /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + for (int64 i = 0; i < result_shape.rank(); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); @@ -725,10 +728,10 @@ Literal LiteralBase::SliceInternal( Literal LiteralBase::Slice(absl::Span start_indices, absl::Span limit_indices) const { - CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; + CHECK(shape().IsArray()) << "tuple is not supported for slice"; DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { + for (int64 dnum = 0; dnum < shape().rank(); ++dnum) { CHECK_GE(start_indices[dnum], 0); CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) << "dnum = " << dnum; @@ -768,6 +771,8 @@ Literal LiteralBase::Slice(absl::Span start_indices, return SliceInternal(result_shape, start_indices); case C64: return SliceInternal(result_shape, start_indices); + case C128: + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -816,6 +821,10 @@ string LiteralBase::GetAsString(absl::Span multi_index, complex64 c = Get(multi_index, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + case C128: { + complex128 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); } @@ -870,6 +879,11 @@ string LiteralBase::GetSparseElementAsString( GetSparseElement(sparse_element_number, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + case 
C128: { + complex128 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: LOG(FATAL) << "Invalid element type for sparse arrays: " << PrimitiveType_Name(subshape.element_type()); @@ -906,7 +920,7 @@ size_t LiteralBase::Hash() const { ShapeUtil::ForEachSubshape( shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsArray(subshape)) { + if (!subshape.IsArray()) { return; } @@ -998,6 +1012,9 @@ void LiteralBase::Piece::SortSparseElements() { case C64: SortSparseElementsInternal(); break; + case C128: + SortSparseElementsInternal(); + break; case F16: SortSparseElementsInternal(); break; @@ -1028,20 +1045,21 @@ string ShapeToString(bool print_layout, const Shape& shape) { } void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces); + bool print_shape, bool print_layout, + std::vector* pieces); void TupleToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - pieces->push_back(ShapeToString(print_layout, subshape)); - pieces->push_back(" (\n"); + pieces->push_back("(\n"); std::vector tuple_pieces; for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { ShapeIndex element_index = shape_index; element_index.push_back(i); std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); + ToStringHelper(literal, element_index, print_shape, print_layout, + &element_pieces); tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); @@ -1049,11 +1067,13 @@ void TupleToStringHelper(const LiteralBase& literal, } void SparseArrayToStringHelper(const LiteralBase& literal, - const 
Shape& subshape, bool print_layout, - std::vector* pieces) { - pieces->push_back(ShapeToString(print_layout, subshape)); + const Shape& subshape, bool print_shape, + bool print_layout, std::vector* pieces) { + if (print_shape) { + pieces->push_back(ShapeToString(print_layout, subshape)); + } pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); int64 num_elements = literal.sparse_element_count(); for (int64 i = 0; i < num_elements; ++i) { if (i > 0) { @@ -1073,10 +1093,10 @@ void SparseArrayToStringHelper(const LiteralBase& literal, } void DenseArrayToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); std::function dimensions, std::vector*)> to_string_recursive = [&](absl::Span dimensions, @@ -1135,7 +1155,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } }; - if (rank > 1) { + if (print_shape) { pieces->push_back(ShapeToString(print_layout, subshape)); pieces->push_back(" "); } @@ -1146,19 +1166,23 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { + bool print_shape, bool print_layout, + std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); CHECK(LayoutUtil::HasLayout(literal.shape())); CHECK(LayoutUtil::HasLayout(subshape)); - if (ShapeUtil::IsTuple(subshape)) { - TupleToStringHelper(literal, shape_index, print_layout, pieces); - } else if (ShapeUtil::IsToken(subshape)) { + if (subshape.IsTuple()) { + TupleToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); + } else if (subshape.IsToken()) { 
pieces->push_back("token"); } else if (LayoutUtil::IsSparseArray(subshape)) { - SparseArrayToStringHelper(literal, subshape, print_layout, pieces); + SparseArrayToStringHelper(literal, subshape, print_shape, print_layout, + pieces); } else { CHECK(LayoutUtil::IsDenseArray(subshape)); - DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); + DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); } } @@ -1169,10 +1193,27 @@ int64 LiteralBase::sparse_element_count() const { return sparse_indices()->index_count(); } -string LiteralBase::ToString(bool print_layout) const { +string LiteralBase::ToString() const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, print_layout, &pieces); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithoutShape() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/false, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithLayout() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/true, &pieces); return absl::StrJoin(pieces, ""); } @@ -1193,7 +1234,7 @@ namespace { template Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, const ConverterType& converter) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); + CHECK(src_literal.shape().IsArray()); Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); @@ -1208,7 +1249,24 @@ Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, } template -Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { +typename std::enable_if<(std::is_same::value) && + 
(std::is_same::value || + std::is_same::value), + Literal>::type +ConvertBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return NativeDestT(static_cast(src)); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(!std::is_same::value) || + (!std::is_same::value && + !std::is_same::value), + Literal>::type +ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1252,22 +1310,6 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } -template -Literal ConvertToC64(const LiteralBase& src_literal) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - Literal result_literal( - ShapeUtil::ChangeElementType(src_literal.shape(), C64)); - using NativeSrcT = - typename primitive_util::PrimitiveTypeToNative::type; - absl::Span src_data = src_literal.data(); - absl::Span dest_data = result_literal.data(); - int64 num_elements = src_literal.element_count(); - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = complex64(static_cast(src_data[i]), 0); - } - return result_literal; -} - template Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); @@ -1297,9 +1339,11 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, bitcast); CONVERT_IF_TYPES_MATCH(PRED) CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S16) CONVERT_IF_TYPES_MATCH(S32) CONVERT_IF_TYPES_MATCH(S64) CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U16) CONVERT_IF_TYPES_MATCH(U32) CONVERT_IF_TYPES_MATCH(U64) CONVERT_IF_TYPES_MATCH(F16) @@ -1308,10 +1352,15 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, CONVERT_IF_TYPES_MATCH(BF16) #undef 
CONVERT_IF_TYPES_MATCH case C64: - if (!bitcast) { - return ConvertToC64(src_literal); + if (bitcast) { + break; } - break; + return ConvertIfTypesMatch(src_literal, false); + case C128: + if (bitcast) { + break; + } + return ConvertIfTypesMatch(src_literal, false); // Other types are not yet supported. default: break; @@ -1324,7 +1373,7 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, StatusOr ConvertSwitch(const LiteralBase& literal, PrimitiveType primitive_dest_type, bool bitcast) { - TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + TF_RET_CHECK(literal.shape().IsArray()); if (literal.shape().element_type() == primitive_dest_type) { return literal.Clone(); } @@ -1335,9 +1384,11 @@ StatusOr ConvertSwitch(const LiteralBase& literal, bitcast); CONVERT_IF_DEST_TYPE_MATCHES(PRED) CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S16) CONVERT_IF_DEST_TYPE_MATCHES(S32) CONVERT_IF_DEST_TYPE_MATCHES(S64) CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U16) CONVERT_IF_DEST_TYPE_MATCHES(U32) CONVERT_IF_DEST_TYPE_MATCHES(U64) CONVERT_IF_DEST_TYPE_MATCHES(F16) @@ -1377,7 +1428,7 @@ StatusOr LiteralBase::BitcastConvert( } StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { - if (!ShapeUtil::IsTuple(dest_shape)) { + if (!dest_shape.IsTuple()) { return Convert(dest_shape.element_type()); } std::vector elements; @@ -1409,7 +1460,7 @@ StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { template bool LiteralBase::Piece::EqualElementsInternal( const LiteralBase::Piece& other, std::vector* multi_index) const { - if (multi_index->size() == ShapeUtil::Rank(subshape())) { + if (multi_index->size() == subshape().rank()) { return (Get(*multi_index) == other.Get(*multi_index)); } for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { @@ -1459,6 +1510,8 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, 
&multi_index); case C64: return EqualElementsInternal(other, &multi_index); + case C128: + return EqualElementsInternal(other, &multi_index); default: LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); @@ -1472,7 +1525,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1502,7 +1555,7 @@ static bool AllElementsEqualValue(absl::Span data, bool LiteralBase::IsAll(int8 value) const { return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1570,30 +1623,24 @@ bool LiteralBase::IsAll(int8 value) const { bool LiteralBase::IsAllFloat(float value) const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue( - piece.data(), static_cast(value)); - default: - return false; - } - }; - if (!piece_is_all()) { - return false; + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; } - return true; }); } @@ 
-1602,6 +1649,9 @@ bool LiteralBase::IsAllComplex(complex64 value) const { case C64: return AllElementsEqualValue(root_piece().data(), value); + case C128: + return AllElementsEqualValue(root_piece().data(), + value); default: return false; } @@ -1610,7 +1660,7 @@ bool LiteralBase::IsAllComplex(complex64 value) const { bool LiteralBase::IsAllFirst() const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1681,6 +1731,11 @@ bool LiteralBase::IsAllFirst() const { auto data = piece.data(); return AllElementsEqualValue(data, data[0]); } + + case C128: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } default: return false; } @@ -1694,11 +1749,11 @@ bool LiteralBase::IsAllFirst() const { } bool LiteralBase::IsR1Iota() const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return false; } - if (ShapeUtil::Rank(shape()) != 1) { + if (shape().rank() != 1) { return false; } @@ -1730,6 +1785,8 @@ bool LiteralBase::IsR1Iota() const { return Get({idx}) == static_cast(idx); case C64: return Get({idx}) == complex64(idx, 0.0f); + case C128: + return Get({idx}) == complex128(idx, 0.0f); case PRED: return Get({idx}) == idx; // token, opaque, tuple, etc. are all not iota. 
@@ -1749,7 +1806,7 @@ bool LiteralBase::IsR1Iota() const { } bool LiteralBase::IsZero(absl::Span indices) const { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); switch (shape().element_type()) { case U8: return Get(indices) == 0; @@ -1773,6 +1830,8 @@ bool LiteralBase::IsZero(absl::Span indices) const { return Get(indices) == 0.0; case C64: return Get(indices) == complex64(0.0f, 0.0f); + case C128: + return Get(indices) == complex128(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); case BF16: @@ -1860,6 +1919,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { proto->add_c64s(value.imag()); } break; + case C128: + for (complex128 value : data()) { + proto->add_c128s(value.real()); + proto->add_c128s(value.imag()); + } + break; case TUPLE: case TOKEN: // Nothing to do but assign the shape which is done above. @@ -1872,12 +1937,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { } const void* LiteralBase::Piece::untyped_data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); return buffer(); } void* LiteralBase::Piece::untyped_data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); return buffer(); } @@ -1908,14 +1973,12 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { if (LayoutUtil::IsSparseArray(subshape())) { // Compute the number of elements (indices) in the sparse shape and reserve // the necessary space in spare_indices. 
- TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0) - << "Scalar shapes cannot be sparse"; - TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0) + TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse"; + TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0) << "Unexpected number of indices in proto (" << proto.sparse_indices_size() << ") for shape of rank " - << ShapeUtil::Rank(subshape()); - const int64 index_count = - proto.sparse_indices_size() / ShapeUtil::Rank(subshape()); + << subshape().rank(); + const int64 index_count = proto.sparse_indices_size() / subshape().rank(); sparse_indices()->Resize(index_count); // Copy the indices from the proto into the SparseIndexArray object. @@ -1994,7 +2057,17 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { for (int64 i = 0; i < complex_data.size(); ++i) { complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; } - } break; + break; + } + case C128: { + auto complex_data = data(); + TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = + complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)}; + } + break; + } case TUPLE: return InvalidArgument("Should not be called on tuple shapes: %s", ShapeUtil::HumanString(subshape())); @@ -2040,8 +2113,8 @@ int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { } string LiteralBase::GetR1U8AsString() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(shape().element_type(), U8); return string(absl::bit_cast(data().data()), ShapeUtil::ElementsIn(shape())); @@ -2055,7 +2128,7 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, << ShapeUtil::HumanString(src_piece->subshape()) << "dest_piece has shape: " << ShapeUtil::HumanString(dest_piece->subshape()); - if (ShapeUtil::IsTuple(shape)) 
{ + if (shape.IsTuple()) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); @@ -2066,7 +2139,7 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, dest_piece->emplace_back(std::move(child_piece)); } - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { dest_piece->set_buffer(src_piece->buffer()); } else { // If the shape is neither an array nor tuple, then it must be @@ -2142,7 +2215,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, : MutableLiteralBase() { shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); - CHECK(!ShapeUtil::IsTuple(*shape_)); + CHECK(!shape_->IsTuple()); root_piece_ = new Piece(); root_piece_->set_buffer(const_cast(src_buf_ptr)); @@ -2169,14 +2242,14 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, : LiteralBase(), root_piece_(&literal.piece(view_root)) {} void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { - CHECK(ShapeUtil::IsTuple(shape)); + CHECK(shape.IsTuple()); for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); auto child_piece = Piece(); child_piece.set_subshape(&subshape); - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { BuildPieceSubtree(subshape, &child_piece); } @@ -2186,7 +2259,7 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : LiteralBase(), shape_(absl::make_unique(shape)) { - CHECK(ShapeUtil::IsArray(*shape_)); + CHECK(shape_->IsArray()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); @@ -2197,7 +2270,7 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape) : LiteralBase(), shape_(absl::make_unique(shape)) { - 
CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(shape_->IsTuple()); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); root_piece_ = Piece(); @@ -2206,7 +2279,7 @@ BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, for (int i = 0; i < src_buf_ptrs.size(); ++i) { const auto& src_shape = shape_->tuple_shapes(i); - CHECK(ShapeUtil::IsArray(src_shape)); + CHECK(src_shape.IsArray()); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index fa9a71af4ceb998a7a289443cbef70eb52cb1a11..c418be895d6c3faa6a85ca2c73c6f42b0a021104 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -92,9 +92,20 @@ class LiteralBase { // array. string GetR1U8AsString() const; - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; + // Returns a string representation of the literal value. The Shape of the + // literal is a prefix of the literal value in the string. + + // Warning: this function can take minutes for multi-million + // element Literals. + string ToString() const; + + // Returns a string representation of the literal value which does *not* + // include the shape string. + string ToStringWithoutShape() const; + + // Returns a string representation of the literal value which includes the + // shape string with its layout.does *not* include the shape string. + string ToStringWithLayout() const; // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. 
@@ -856,7 +867,7 @@ class BorrowingLiteral : public LiteralBase { template absl::Span LiteralBase::Piece::data() const { - DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); DCHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) << "Attempting to access " @@ -869,7 +880,7 @@ absl::Span LiteralBase::Piece::data() const { template absl::Span LiteralBase::Piece::data() { - DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); DCHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) << "Attempting to access " @@ -950,8 +961,12 @@ void MutableLiteralBase::AppendSparseElement( Piece& p = piece(shape_index); const Shape& subshape = p.subshape(); CHECK(LayoutUtil::IsSparseArray(subshape)); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); CHECK_EQ(multi_index.size(), rank); + for (int64 i = 0; i < rank; ++i) { + CHECK_GE(multi_index[i], 0); + CHECK_LT(multi_index[i], subshape.dimensions(i)); + } int64 last_element = p.sparse_indices()->index_count(); CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); p.sparse_indices()->Append(multi_index); @@ -966,7 +981,7 @@ void LiteralBase::EachCell( if (ShapeUtil::IsZeroElementArray(shape())) { return; } - std::vector indices(ShapeUtil::Rank(shape()), 0); + std::vector indices(shape().rank(), 0); do { per_cell(indices, Get(indices)); } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); @@ -974,8 +989,8 @@ void LiteralBase::EachCell( template inline void MutableLiteralBase::PopulateR1(absl::Span values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); CHECK_EQ(shape().element_type(), 
primitive_util::NativeToPrimitiveType()); @@ -986,8 +1001,8 @@ inline void MutableLiteralBase::PopulateR1(absl::Span values) { template void MutableLiteralBase::PopulateR2( std::initializer_list> values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 2); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1010,10 +1025,10 @@ void MutableLiteralBase::PopulateR2( template void MutableLiteralBase::PopulateFromArray(const Array& values) { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); - CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + CHECK_EQ(shape().rank(), values.num_dimensions()); for (int dim = 0; dim < values.num_dimensions(); ++dim) { CHECK_EQ(values.dim(dim), shape().dimensions(dim)); } @@ -1042,7 +1057,7 @@ void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, absl::Span values, bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); - int rank = ShapeUtil::Rank(shape()); + int rank = shape().rank(); CHECK_EQ(indices.rank(), rank); int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); CHECK_LE(indices.max_indices(), max_elements); @@ -1066,7 +1081,7 @@ template Status MutableLiteralBase::PopulateInternal(const FnType& generator, bool parallel) { const Shape& this_shape = shape(); - const int64 rank = ShapeUtil::Rank(this_shape); + const int64 rank = this_shape.rank(); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); @@ -1118,7 +1133,7 @@ Status MutableLiteralBase::PopulateParallel(const FnType& generator) { template void MutableLiteralBase::PopulateWithValue(NativeT value) { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); for (NativeT& 
element : data()) { diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index b044f0ad73f13a0599e77f1f43888bc974e31f73..9b3de75dd4e9d495778af86fb8fc07909ab4ba81 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -46,68 +46,116 @@ uint16 GetRawValue(Eigen::half val) { return val.x; } // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template -Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, - absl::Span multi_index) { +bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { + auto ulhs = absl::bit_cast(GetRawValue(lhs)); + auto urhs = absl::bit_cast(GetRawValue(rhs)); + return ulhs == urhs; +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +bool CompareEqual(NativeT lhs, NativeT rhs, + absl::Span multi_index) { + return lhs == rhs; +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. 
+template <> +bool CompareEqual(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(float lhs, float rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(double lhs, double rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} +template <> +bool CompareEqual(complex128 lhs, complex128 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} + +template +Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { auto ulhs = absl::bit_cast(GetRawValue(lhs)); auto urhs = absl::bit_cast(GetRawValue(rhs)); auto lhs_double = static_cast(lhs); auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a at array index %s", StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, StrCat(absl::Hex(urhs)), rhs_double, rhs_double, LiteralUtil::MultiIndexAsString(multi_index)); - } - return Status::OK(); } -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). 
template -Status CompareEqual(NativeT lhs, NativeT rhs, - absl::Span multi_index) { - if (lhs == rhs) { - return Status::OK(); - } +Status MakeErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { return InvalidArgument( "first mismatch at array index %s:\n expected value: %s\n actual " "value: %s", LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. template <> -Status CompareEqual(bfloat16 lhs, bfloat16 rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(Eigen::half lhs, Eigen::half rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(float lhs, float rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(float lhs, float rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(double lhs, double rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(double lhs, double rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(complex64 lhs, complex64 rhs, - absl::Span multi_index) { - auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); - if (!res.ok()) { - return res; +Status MakeErrorStatus(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return 
MakeErrorStatus(lhs.real(), rhs.real(), multi_index); } - return CompareEqual(lhs.imag(), rhs.imag(), multi_index); + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); +} +template <> +Status MakeErrorStatus(complex128 lhs, complex128 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); + } + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); } // A recursive function which iterates through every index of expected and @@ -119,7 +167,11 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); - return CompareEqual(expected_value, actual_value, multi_index); + bool result = + CompareEqual(expected_value, actual_value, multi_index); + return result ? Status::OK() + : MakeErrorStatus(expected_value, actual_value, + multi_index); } Status result; @@ -134,53 +186,40 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, // Gets the total element count. For tuples, this is not the count of tuple // elements, but the sum of elements of each tuple element. int64 RecursiveElementCount(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); int64 total = 0; for (int64 i = 0; i < tuple_elements; ++i) { total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); } return total; - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { return ShapeUtil::ElementsIn(shape); } else { return 0; } } -// Returns whether the actual and expected values are mismatched with respect to -// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +// Returns whether the given value is infinity. 
template -bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return !std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); - } -} - -template <> -bool NanMismatch(complex64 expected, complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); +bool IsInf(NativeT val) { + return std::isinf(val); } template <> -bool NanMismatch(half expected, half actual, bool relaxed_nans) { - return NanMismatch(static_cast(expected), - static_cast(actual), relaxed_nans); +bool IsInf(half val) { + return std::isinf(static_cast(val)); } -// Returns whether the given value is infinity. +// Returns whether the given value is nan. template -bool IsInf(NativeT val) { - return std::isinf(val); +float IsNan(NativeT value) { + return std::isnan(value); } template <> -bool IsInf(half val) { - return std::isinf(static_cast(val)); +float IsNan(half value) { + return IsNan(static_cast(value)); } // Converts the given floating-point value to a string. @@ -194,6 +233,11 @@ string FpValueToString(complex64 value) { return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } +template <> +string FpValueToString(complex128 value) { + return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); +} + // Returns the absolute value of the given floating point value. This function // is used instead of std::abs directly in order to allow type-dependent // implementations for NearComparator. @@ -273,7 +317,7 @@ class NearComparator { // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. 
TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); - if (!ShapeUtil::IsArray(expected_.shape())) { + if (!expected_.shape().IsArray()) { return InvalidArgument("Expected array shape; got %s.", ShapeUtil::HumanString(expected_.shape())); } @@ -326,35 +370,59 @@ class NearComparator { // the given literal_index and keeps track of various mismatch statistics. template void CompareValues(T expected, T actual, int64 linear_index) { - const bool is_nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; - if (CompareEqual(expected, actual, {linear_index}).ok()) { + if (CompareEqual(expected, actual, {linear_index})) { abs_error = 0; rel_error = 0; - } else if (is_nan_mismatch) { - num_nan_mismatches_++; - // A nan mismatch is considered to have infinite error. rel_error is used - // for sorting a std::set of the top mismatchs, and a nan value here will - // result in undefined behavior because nan's do not satisfy the strict - // weak ordering requirement of std containers. - abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); + } else if (IsNan(expected) || IsNan(actual)) { + if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) || + (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is + // used for sorting a std::set of the top mismatchs, and a nan value + // here will result in undefined behavior because nan's do not satisfy + // the strict weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = 0; + rel_error = 0; + } + } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) { + // `fewer_infs_ok` gives us the option of comparing as though `actual` + // were float_max/min rather than inf. + T actual_finite = actual > T{0} ? 
std::numeric_limits::max() + : std::numeric_limits::lowest(); + abs_error = FpAbsoluteValue(actual_finite - expected); + + // Avoid division by 0 even though it's well-defined because ubsan can be + // configured to treat this as a fatal error. + if (expected != T{0}) { + rel_error = abs_error / FpAbsoluteValue(expected); + } else { + rel_error = std::numeric_limits::infinity(); + } } else if (IsInf(expected) || IsInf(actual)) { // If either the expected or actual value is infinity but not both, // then both absolute and relative error are regarded as inifity. - CHECK(!CompareEqual(expected, actual, {linear_index}).ok()); + CHECK(!CompareEqual(expected, actual, {linear_index})); abs_error = std::numeric_limits::infinity(); rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); - rel_error = abs_error / FpAbsoluteValue(expected); + + // Avoid division by 0 even though it's well-defined because ubsan can be + // configured to treat this as a fatal error. + if (expected != T{0}) { + rel_error = abs_error / FpAbsoluteValue(expected); + } else { + rel_error = std::numeric_limits::infinity(); + } } const bool is_abs_mismatch = abs_error > error_.abs; const bool is_rel_mismatch = rel_error > error_.rel; - const bool is_mismatch = - is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + const bool is_mismatch = is_abs_mismatch && is_rel_mismatch; // Update the error of the relative bucket only if the *absolute* error // bound is exceeded and vice versa. @@ -389,7 +457,7 @@ class NearComparator { mismatches_.data()[linear_index] = true; } - // For complex64 types, we compare real and imaginary parts individually. + // For complex types, we compare real and imaginary parts individually. 
void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { bool mismatch = false; CompareValues(expected.real(), actual.real(), linear_index); @@ -412,6 +480,29 @@ class NearComparator { mismatches_.data()[linear_index] = mismatch; } + void CompareValues(complex128 expected, complex128 actual, + int64 linear_index) { + bool mismatch = false; + CompareValues(expected.real(), actual.real(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for real part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + CompareValues(expected.imag(), actual.imag(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for imag part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + if (mismatch == true) { + num_mismatches_++; + } + mismatches_.data()[linear_index] = mismatch; + } + // Compares the two literals elementwise. void CompareLiterals() { // Fast path optimization for the case were layouts match. @@ -425,7 +516,7 @@ class NearComparator { } return; } - std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + std::vector multi_index(actual_.shape().rank(), 0); CompareLiteralsSlow(0, &multi_index); } @@ -620,6 +711,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { case C64: result = Equal(expected, actual, index, 0); break; + case C128: + result = Equal(expected, actual, index, 0); + break; case TUPLE: { for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { result.Update(EqualHelper(LiteralSlice(expected, {i}), @@ -642,12 +736,12 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. 
Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback, const ShapeIndex& shape_index) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); - if (ShapeUtil::IsTuple(expected.shape())) { + if (expected.shape().IsTuple()) { Status return_status; for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { const auto expected_element = LiteralSlice(expected, {i}); @@ -683,26 +777,32 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { + bool use_detailed_message = detailed_message.value_or( + ShapeUtil::ElementsIn(expected.shape()) >= 64); switch (expected.shape().element_type()) { case BF16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F32: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case C64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); + break; + case C128: + return NearComparator::Compare( + expected, actual, error, use_detailed_message, miscompare_callback); break; default: 
LOG(FATAL) << "Unsupported primitive type in near comparator: " @@ -723,7 +823,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); } - if (ShapeUtil::IsTuple(expected)) { + if (expected.IsTuple()) { if (ShapeUtil::TupleElementCount(expected) != ShapeUtil::TupleElementCount(actual)) { return InvalidArgument( @@ -738,8 +838,8 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return AppendStatus(result, StrCat("mismatch in tuple index", i)); } } - } else if (ShapeUtil::IsArray(expected)) { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + } else if (expected.IsArray()) { + if (expected.rank() != actual.rank()) { return InvalidArgument("want rank of %s got rank of %s", ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); @@ -793,7 +893,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback) { VLOG(1) << "Expected literal:"; XLA_VLOG_LINES(1, expected.ToString()); diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 9e5bf7c1d062ef0f25d07a80d6ded8106df5dacc..23fff3fa348f1652eaec344da4c40ccf3ad1079a 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -55,9 +55,10 @@ using MiscompareCallback = // being compared. // // If detailed_message is true, then the error message in the assertion result -// will contain a more detailed breakdown of mismatches. +// will contain a more detailed breakdown of mismatches. By default, we display +// a detailed message only for "large" inputs. 
Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback); // Calling ToString on a literal with over 100 million elements takes around diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 49363ad802ddb9520f89b53257216bc7ddaf8ff5..b54a71ae68218ef578535a913f5867d843236e32 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -98,42 +98,45 @@ class LiteralUtilTest : public ::testing::Test { TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - EXPECT_EQ("true", true_lit.ToString()); + EXPECT_EQ("pred[] true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - EXPECT_EQ("false", false_lit.ToString()); + EXPECT_EQ("pred[] false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - EXPECT_EQ("42", u32_lit.ToString()); + EXPECT_EQ("u32[] 42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - EXPECT_EQ("-999", s32_lit.ToString()); + EXPECT_EQ("s32[] -999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - EXPECT_EQ("3.14", f32_lit.ToString()); + EXPECT_EQ("f32[] 3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", f16_lit.ToString()); + EXPECT_EQ("f16[] 0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); + EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString()); + + auto c128_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); + EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", bf16_lit.ToString()); + EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString()); // 3.14 will be rounded to 3.14062 in bfloat16 format. 
auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); + ASSERT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - EXPECT_EQ("9", bf16_lit_truncated2.ToString()); + EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{1, 0, 1}", pred_vec.ToString()); + EXPECT_EQ("pred[3] {1, 0, 1}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -210,8 +213,8 @@ TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - const string expected = R"((f32[], f32[2,2]) ( -1, + const string expected = R"(( +f32[] 1, f32[2,2] { { 1, 2 }, { 3, 4 } @@ -469,6 +472,21 @@ TEST_F(LiteralUtilTest, C64Equality) { EXPECT_NE(vector, vector_reversed); } +TEST_F(LiteralUtilTest, C128Equality) { + // Test equality with tuples. + auto vector = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. 
+ auto vector_clone = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + EXPECT_EQ(vector, vector_clone); + + auto vector_reversed = + LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + EXPECT_NE(vector, vector_reversed); +} + TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); @@ -623,7 +641,7 @@ template class LiteralUtilTestTemplated : public ::testing::Test {}; using TestedTypes = ::testing::Types; -TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); +TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. @@ -836,6 +854,13 @@ TEST_F(LiteralUtilTest, PopulateR1C64) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateR1C128) { + Literal output(ShapeUtil::MakeShape(C128, {1})); + output.PopulateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, PopulateR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); @@ -897,6 +922,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR2C128) { + Literal output(ShapeUtil::MakeShape(C128, {2, 2})); + output.PopulateWithValue({4, 2}); + auto expected = + LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); @@ -1237,11 +1270,21 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); + auto s16 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, 
layout_r4_dim0major_); auto s32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); + auto u16 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); auto u32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, @@ -1298,9 +1341,19 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - // clang-format on + auto c128 = LiteralUtil::CreateR4WithLayout({{ + {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, + {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, + {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, + }}, layout_r4_dim0major_); // clang-format on Literal conv; + conv = s8.Convert(U16).ConsumeValueOrDie(); + EXPECT_EQ(conv, u16); + + conv = s8.Convert(S16).ConsumeValueOrDie(); + EXPECT_EQ(conv, s16); + conv = s8.Convert(U32).ConsumeValueOrDie(); EXPECT_EQ(conv, u32); @@ -1352,12 +1405,26 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = f16.Convert(C64).ConsumeValueOrDie(); EXPECT_EQ(conv, c64); + conv = s32.Convert(S16).ConsumeValueOrDie(); + EXPECT_EQ(conv, s16); + + conv = s32.Convert(U16).ConsumeValueOrDie(); + EXPECT_EQ(conv, u16); + + conv = s32.Convert(C128).ConsumeValueOrDie(); + EXPECT_EQ(conv, c128); + + conv = f16.Convert(C128).ConsumeValueOrDie(); + EXPECT_EQ(conv, c128); + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); 
EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c128.Convert(F32).status().code(), + tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c128.Convert(S32).status().code(), + tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1642,7 +1709,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]})); Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); - ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_TRUE(literal.shape().IsTuple()); ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); EXPECT_EQ(literal.Get({}, /*shape_index=*/{0}), 1.0); @@ -1659,7 +1726,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { Literal literal = Literal::MoveIntoTuple({}); - ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_TRUE(literal.shape().IsTuple()); EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); } @@ -1719,7 +1786,8 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}), + ShapeUtil::MakeShape(C128, {})})); EXPECT_EQ(tuple.Get({}, {0}), 0.0); EXPECT_EQ(tuple.Get({0}, {1}), false); @@ -1727,6 +1795,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); + EXPECT_EQ(tuple.Get({}, {4}), complex128(0.0, 0.0)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1736,6 +1805,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); auto vector_uint8 = LiteralUtil::CreateR1({128, 0, 2, 56, 127, 255}); auto 
vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_c128 = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = @@ -1756,6 +1827,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8)); EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(tuple, to_from_proto(tuple)); @@ -1890,7 +1962,7 @@ TEST_F(LiteralUtilTest, SortSparseElements) { literal.AppendSparseElement({3, 4, 5}, 3.0); literal.AppendSparseElement({1, 2, 3}, 1.0); literal.SortSparseElements(); - EXPECT_EQ(literal.ToString(false), + EXPECT_EQ(literal.ToString(), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index bb5e5e61000d0aca6ab052ac87d2fbcd96e55f70..26b029c8d0c52e38510f9279def7c4af2904931d 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -62,7 +62,7 @@ Literal ConvertType(LiteralSlice literal) { ShapeUtil::ForEachSubshape( literal.shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { if (subshape.element_type() == primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); @@ -106,12 +106,16 @@ Literal ConvertType(LiteralSlice literal) { switch (primitive_type) { case U8: return LiteralUtil::CreateR0(0); + case U16: + return LiteralUtil::CreateR0(0); case U32: return LiteralUtil::CreateR0(0); case U64: return LiteralUtil::CreateR0(0); case S8: return LiteralUtil::CreateR0(0); + case S16: + return LiteralUtil::CreateR0(0); case S32: return 
LiteralUtil::CreateR0(0); case S64: @@ -126,11 +130,10 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(0); case C64: return LiteralUtil::CreateR0(0); + case C128: + return LiteralUtil::CreateR0(0); case PRED: return LiteralUtil::CreateR0(false); - case S16: - case U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; case OPAQUE: @@ -164,6 +167,8 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(1); case C64: return LiteralUtil::CreateR0(1); + case C128: + return LiteralUtil::CreateR0(1); case PRED: return LiteralUtil::CreateR0(true); case S16: @@ -200,6 +205,8 @@ Literal ConvertType(LiteralSlice literal) { -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; + case C128: + LOG(FATAL) << "C128 element type has no minimum value"; case PRED: return LiteralUtil::CreateR0(false); case S16: @@ -344,6 +351,10 @@ Literal ConvertType(LiteralSlice literal) { new_literal.Set(to_multi_index, literal.Get(from_multi_index)); break; + case C128: + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); + break; default: LOG(FATAL) << "Unhandled primitive element type: " << PrimitiveType_Name(literal.shape().element_type()); @@ -355,7 +366,7 @@ Literal ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::GetFirstScalarLiteral( const LiteralSlice& literal) { - CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK(literal.shape().IsArray()); CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); switch (literal.shape().element_type()) { case PRED: @@ -392,6 +403,10 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: return LiteralUtil::CreateR0(literal.GetFirstElement()); + + case C128: + return LiteralUtil::CreateR0( + literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << 
literal.shape().element_type(); diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 4eab4fa4290c270697c00be20840cf4e85459183..bad65ac32018fafcc7634b989f1b4b0867aa5c0d 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/metric_table_report.h" -#include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +55,7 @@ string MetricTableReport::MakeReport(double expected_metric_sum) { const auto metric_greater = [](const Entry& a, const Entry& b) { return a.metric > b.metric; }; - std::sort(entries_.begin(), entries_.end(), metric_greater); + absl::c_sort(entries_, metric_greater); // Create the report AppendLine(); @@ -117,7 +117,7 @@ std::vector MetricTableReport::MakeCategories( auto metric_sum_greater = [](const Category& a, const Category& b) { return a.metric_sum > b.metric_sum; }; - std::sort(categories.begin(), categories.end(), metric_sum_greater); + absl::c_sort(categories, metric_sum_greater); return categories; } @@ -249,7 +249,7 @@ string MetricTableReport::MetricString(double metric) { string output; // Copy leading non-digit characters unconditionally. // This picks up the leading sign. 
- while (!sp1.empty() && !isdigit(sp1[0])) { + while (!sp1.empty() && !absl::ascii_isdigit(sp1[0])) { output.push_back(sp1[0]); sp1.remove_prefix(1); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 0f86f9f35e105713aa3072a9ebf572d33d35d66d..339660cf44fd64fc5859e72255d63762fcf20efe 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -42,8 +42,7 @@ PackedLiteralReader::~PackedLiteralReader() { delete file_; } StatusOr PackedLiteralReader::Read(const Shape& shape, const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) - << " layout: " - << (layout == nullptr ? "" : layout->ShortDebugString()); + << " layout: " << (layout == nullptr ? "" : layout->ToString()); Shape literal_shape = shape; if (layout != nullptr) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index 5b568888d14f21c1330556d017eafba6c8dd2228..e1e22f784172b5f3850f0bc510322dfad9e7f1bb 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -37,7 +38,7 @@ limitations under the License. namespace xla { -static const char kWS[] = " \t\r\n"; // whitespace +static const char kWS[] = " \t\r\n"; // whitespace // The following struct represents an argv[]-style array, parsed // from data gleaned from the environment. @@ -104,7 +105,8 @@ static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { // Set e to the index just past the end of the flag. 
size_t e = b; while (e != flag_str.size() && isascii(flag_str[e]) && - (strchr("-_", flag_str[e]) != nullptr || isalnum(flag_str[e]))) { + (strchr("-_", flag_str[e]) != nullptr || + absl::ascii_isalnum(flag_str[e]))) { e++; } if (e != flag_str.size() && flag_str[e] == '=' && @@ -184,6 +186,14 @@ bool ParseFlagsFromEnvAndDieIfUnknown( tensorflow::mutex_lock lock(env_argv_mu); auto* env_argv = &EnvArgvs()[string(envvar)]; SetArgvFromEnv(envvar, env_argv); // a no-op if already initialized + + if (VLOG_IS_ON(1)) { + VLOG(1) << "For env var " << envvar << " found arguments:"; + for (int i = 0; i < env_argv->argc; i++) { + VLOG(1) << " argv[" << i << "] = " << env_argv->argv[i]; + } + } + bool result = tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index b16147e3be71771269d8b7a18528bef3a8c72d99..1eedddf72c1d393cb1b88e589881e24de02ad802 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -15,16 +15,35 @@ limitations under the License. 
#include "tensorflow/compiler/xla/primitive_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace primitive_util { +int SignificandWidth(PrimitiveType type) { + switch (type) { + case F32: + return std::numeric_limits<float>::digits; + case F64: + return std::numeric_limits<double>::digits; + case BF16: + return kBFloat16MantissaBits + 1; + case F16: + return 11; + default: + LOG(FATAL) << "Not a floating data type " << type; + } +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16; } -bool IsComplexType(PrimitiveType type) { return type == C64; } +bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; } bool IsSignedIntegralType(PrimitiveType type) { return type == S8 || type == S16 || type == S32 || type == S64; @@ -64,6 +83,9 @@ int BitWidth(PrimitiveType type) { case C64: return 64; + case C128: + return 128; + case TUPLE: LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; @@ -75,10 +97,27 @@ int BitWidth(PrimitiveType type) { } } +xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) { + switch (src_bitwidth) { + case 8: + return xla::U8; + case 16: + return xla::U16; + case 32: + return xla::U32; + case 64: + return xla::U64; + default: + return xla::PRIMITIVE_TYPE_INVALID; + } +} + PrimitiveType ComplexComponentType(PrimitiveType complex_type) { switch (complex_type) { case C64: return F32; + case C128: + return F64; default: LOG(FATAL) << "Primitive type is not complex: " << PrimitiveType_Name(complex_type); @@ -90,5 +129,65 @@ bool IsArrayType(PrimitiveType primitive_type) { primitive_type != OPAQUE && primitive_type != TOKEN; } +// Class to memoize the computation of +// absl::AsciiStrToLower(PrimitiveType_Name(p)) +// for all PrimitiveType values "p" +class
PrimitiveTypeNameGenerator { + public: + PrimitiveTypeNameGenerator() { + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + lowercase_name_[i] = absl::AsciiStrToLower( + PrimitiveType_Name(static_cast<PrimitiveType>(i))); + } + } + } + const string& LowercaseName(PrimitiveType t) { + return lowercase_name_[static_cast<int>(t)]; + } + + private: + string lowercase_name_[PrimitiveType_ARRAYSIZE]; +}; + +const string& LowercasePrimitiveTypeName(PrimitiveType s) { + static auto* gen = new PrimitiveTypeNameGenerator(); + return gen->LowercaseName(s); +} + +namespace { + +// Returns a map from lower-case primitive type name to primitive type. +const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() { + static std::unordered_map<string, PrimitiveType>* name_to_type = [] { + static auto* map = new std::unordered_map<string, PrimitiveType>; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) { + auto value = static_cast<PrimitiveType>(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + return *name_to_type; +} + +} // namespace + +StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + if (found == map.end()) { + return InvalidArgument("Invalid element type string: \"%s\".", name); + } + return found->second; +} + +bool IsPrimitiveTypeName(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + return found != map.end(); +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 889e9a1ceca675689406d255d348c82c398563aa..295d353003276b4c1731f7d6a378fd1ae0288d3c 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -20,12 +20,19 @@ limitations under the License.
#include +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace primitive_util { +// Returns the count of significand (mantissa) bits for float datatypes. +// For non-float datatypes, results in a LOG(FATAL). +int SignificandWidth(PrimitiveType type); + // The number of exponent bits in a BF16 value. const int kBFloat16ExponentBits = 8; @@ -123,6 +130,11 @@ inline PrimitiveType NativeToPrimitiveType() { return C64; } +template <> +inline PrimitiveType NativeToPrimitiveType() { + return C128; +} + bool IsFloatingPointType(PrimitiveType type); bool IsComplexType(PrimitiveType type); @@ -139,6 +151,8 @@ bool IsArrayType(PrimitiveType primitive_type); // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); +PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth); + // Returns the real, imag component type underlying the given complex type. // LOG(FATAL)'s if complex_type is not complex. PrimitiveType ComplexComponentType(PrimitiveType complex_type); @@ -221,6 +235,22 @@ template <> struct PrimitiveTypeToNative { using type = complex64; }; + +template <> +struct PrimitiveTypeToNative { + using type = complex128; +}; + +// Returns the lower-case name of the given primitive type. +const string& LowercasePrimitiveTypeName(PrimitiveType s); + +// Returns the PrimitiveType matching the given name. The given name is expected +// to be lower-case. +StatusOr StringToPrimitiveType(absl::string_view name); + +// Returns true if the given name is a primitive type string (lower-case). 
+bool IsPrimitiveTypeName(absl::string_view name); + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util_test.cc b/tensorflow/compiler/xla/primitive_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f765d6da9ef65849fe8ede56ced7597d623cb59 --- /dev/null +++ b/tensorflow/compiler/xla/primitive_util_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/primitive_util.h" + +#include +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +TEST(PrimitiveUtilTest, StringToPrimitiveType) { + auto expect_ok_and_equal = [](const string& str, PrimitiveType expected) { + TF_ASSERT_OK_AND_ASSIGN(PrimitiveType actual, + primitive_util::StringToPrimitiveType(str)); + EXPECT_EQ(expected, actual); + }; + expect_ok_and_equal("f32", F32); + expect_ok_and_equal("tuple", TUPLE); + expect_ok_and_equal("pred", PRED); + expect_ok_and_equal("s32", S32); + + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("F32").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("Pred").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("preD").status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index f22fc8b8499dd4a5329276040331a2ed9e89bea9..4a88a48f2857a327aba3600ca72191e5c7b28585 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ +#include "google/protobuf/duration.pb.h" +#include "absl/time/time.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/protobuf.h" @@ -43,6 +45,20 @@ Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, // dirpath along as-is. 
void RegisterDirectoryExpander(const std::function& expander); +// Converts an absl::Duration to a google::protobuf::Duration. +inline google::protobuf::Duration ToDurationProto(absl::Duration duration) { + google::protobuf::Duration proto; + proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); + proto.set_nanos( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + return proto; +} + +// Converts a google::protobuf::Duration to an absl::Duration. +inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 63ac1c6649210cbae9e238a74e0a45fb8ee4da63..55eacc1c16a76522215d27ac7cf4e801e69c9740 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -3,7 +3,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins") py_library( name = "xla_client", @@ -17,6 +18,12 @@ py_library( ], ) +pyx_library( + name = "custom_call_for_test", + testonly = True, + srcs = ["custom_call_for_test.pyx"], +) + py_test( name = "xla_client_test", srcs = ["xla_client_test.py"], @@ -24,6 +31,7 @@ py_test( srcs_version = "PY2AND3", tags = ["no_oss"], deps = [ + ":custom_call_for_test", ":xla_client", "//tensorflow/python:platform_test", ], @@ -51,10 +59,6 @@ cc_library( srcs = ["local_computation_builder.cc"], hdrs = ["local_computation_builder.h"], deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", 
"//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -66,9 +70,37 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/core:lib", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "xrt", + srcs = ["xrt.cc"], + hdrs = ["xrt.h"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", "//tensorflow/core:framework", @@ -80,11 +112,19 @@ cc_library( tf_py_wrap_cc( name = "pywrap_xla", - srcs = ["xla.i"], + srcs = [ + "xla.i", + ], swig_includes = [ "local_computation_builder.i", + "xla_data.i", "//tensorflow/python:platform/base.i", ], + version_script = select({ + "//tensorflow:darwin": "pywrap_xla_exported_symbols.lds", + "//tensorflow:windows": None, + "//conditions:default": "pywrap_xla_version_script.lds", + }), deps = [ ":local_computation_builder", ":numpy_bridge", @@ 
-92,7 +132,29 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", - ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/service:gpu_plugin", - ]), + ] + xla_python_default_plugins(), +) + +tf_py_wrap_cc( + name = "pywrap_xrt", + srcs = [ + "xrt.i", + ], + swig_includes = [ + "xla_data.i", + "//tensorflow/python:platform/base.i", + ], + version_script = select({ + "//tensorflow:darwin": "pywrap_xla_exported_symbols.lds", + "//tensorflow:windows": None, + "//conditions:default": "pywrap_xla_version_script.lds", + }), + visibility = ["//visibility:public"], + deps = [ + ":numpy_bridge", + ":xrt", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + ], ) diff --git a/tensorflow/compiler/xla/python/custom_call_for_test.pyx b/tensorflow/compiler/xla/python/custom_call_for_test.pyx new file mode 100644 index 0000000000000000000000000000000000000000..530dffd1755d8438f52569c223525000c97df6ea --- /dev/null +++ b/tensorflow/compiler/xla/python/custom_call_for_test.pyx @@ -0,0 +1,21 @@ +# distutils: language = c++ + +# Test case for defining a XLA custom call target in Cython, and registering +# it via the xla_client SWIG API. 
+ +from cpython.pycapsule cimport PyCapsule_New + +cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil: + cdef float a = (<float*>(data_ptr[0]))[0] + cdef float b = (<float*>(data_ptr[1]))[0] + cdef float* out = <float*>(out_ptr) + out[0] = a - b + + +cpu_custom_call_targets = {} + +cdef register_custom_call_target(fn_name, void* fn): + cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET" + cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL) + +register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32)) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 6e2ee866321a070d55a7221c7c68024ceaa93448..c14a01a858af414fc78a5f727372e8fa64cad4b8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -20,25 +20,22 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/cc/client/client_session.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include
"tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -47,127 +44,80 @@ limitations under the License. namespace xla { namespace swig { -// TODO(b/118641336): Factor out XRT parts into a small c++ library of their -// own. - -// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of -// device handles instead of needing to set the number of replicas at XLA -// service initialization time. -tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); -int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; -LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; - -string* GetPlatformNameString() { - static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = - new string("Host"); - return platform_name_string; -} - -Status InitializeReplicaCount(int replica_count) { - if (replica_count < 1) { - return InvalidArgument("Replica count must be >= 1; got %d.", - replica_count); - } - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the replica count to %d, but a local XLA service was " - "previously created with a replica count of %d.", - replica_count, g_replica_count); - } - g_replica_count = replica_count; - return Status::OK(); -} - -Status InitializePlatformName(const string& platform_name) { - string* g_platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted 
to set the platform name to %s, but a local XLA service was " - "previously created with a platform name of %s.", - platform_name, *g_platform_name); +Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { + const char* name = "xla._CPU_CUSTOM_CALL_TARGET"; + if (!PyCapsule_IsValid(capsule, name)) { + return InvalidArgument( + "Argument to RegisterCpuCustomCallTargetRegistry was not a " + "xla._CPU_CUSTOM_CALL_TARGET capsule."); } - TF_RETURN_IF_ERROR(PlatformUtil::GetPlatform(platform_name).status()); - *g_platform_name = platform_name; + void* fn_ptr = PyCapsule_GetPointer(capsule, name); + CHECK(fn_ptr != nullptr); + cpu::CustomCallTargetRegistry::Global()->Register( + std::string(fn_name.begin(), fn_name.end()), fn_ptr); return Status::OK(); } -int GetReplicaCount() { - tensorflow::mutex_lock lock(g_local_client_mutex); - return g_replica_count; -} +LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {} -LocalClient* GetOrCreateLocalClient() { - string* platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return g_local_client; +/* static */ StatusOr LocalClient::Get( + const string& platform_name) { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(platform_name)); + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("Platform %s has no visible devices.", + platform_name); } LocalClientOptions options; - options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); - options.set_number_of_replicas(g_replica_count); - g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); - CHECK(g_local_client != nullptr); - return g_local_client; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(xla::LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + CHECK(client != nullptr); + return LocalClient(client); } -Status TransferToInfeedLocal(const 
Literal& literal) { - VLOG(1) << "Infeeding literal without replica number; shape: " - << literal.shape(); - LocalClient* client = GetOrCreateLocalClient(); - return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); -} +// Returns the number of devices known to the XLA client. +int LocalClient::DeviceCount() const { return client_->device_count(); } -Status TransferToInfeedLocalReplica(const Literal& literal, - int replica_number) { - VLOG(1) << "Infeeding shape " << literal.shape() - << " to replica number: " << replica_number; - LocalClient* client = GetOrCreateLocalClient(); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferToInfeedLocal(literal, device_ordinal); +Status LocalClient::TransferToInfeed(const Literal& literal, + int device_ordinal) { + VLOG(1) << "Infeeding literal to device " << device_ordinal + << "; shape: " << literal.shape(); + return client_->TransferToInfeed(literal, device_ordinal); } -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number) { - VLOG(1) << "Outfeeding literal from replica number: " << replica_number - << " shape: " << shape; - LocalClient* client = GetOrCreateLocalClient(); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferFromOutfeedLocal(shape, device_ordinal); -} - -static StatusOr ToBuffer(LocalClient* client, - int device_ordinal, - const Literal& arg) { - return client->LiteralToShapedBuffer(arg, device_ordinal, - client->backend().memory_allocator()); +StatusOr LocalClient::TransferFromOutfeed(const Shape& shape, + int device_ordinal) { + VLOG(1) << "Outfeeding literal from device " << device_ordinal + << "; shape: " << shape; + return client_->TransferFromOutfeed(&shape, device_ordinal); } /* static */ StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number) { - 
LocalClient* client = GetOrCreateLocalClient(); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " - << replica_number << "/" << device_ordinal; + const LocalClient& client, int device_ordinal) { + VLOG(1) << "Creating shaped buffer from literal on device ordinal: " + << device_ordinal; + auto literal_to_buffer = [&](const Literal& arg) { + return client.client()->LiteralToShapedBuffer( + arg, device_ordinal, client.client()->backend().memory_allocator()); + }; + StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, device_ordinal, relaid); + return literal_to_buffer(relaid); } - return ToBuffer(client, device_ordinal, argument); + return literal_to_buffer(argument); }(); TF_RETURN_IF_ERROR(buf.status()); - return new LocalShapedBuffer(std::move(buf).ValueOrDie()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client()); } -LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) - : shaped_buffer_(std::move(shaped_buffer)) {} +LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, + xla::LocalClient* client) + : shaped_buffer_(std::move(shaped_buffer)), client_(client) {} const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; @@ -180,8 +130,7 @@ const Shape& LocalShapedBuffer::shape() const { } StatusOr LocalShapedBuffer::ToLiteral() const { - LocalClient* client = GetOrCreateLocalClient(); - return client->ShapedBufferToLiteral(*shaped_buffer()); + return client_->ShapedBufferToLiteral(*shaped_buffer()); } LocalShapedBufferTuple::LocalShapedBufferTuple( @@ -212,141 +161,94 @@ StatusOr LocalShapedBufferTuple::Release(int i) { int64 LocalShapedBufferTuple::size() const { return elements_.size(); } -XrtAllocation::XrtAllocation(int64 handle, Shape shape, - const string& 
session_target) - : handle_(handle), shape_(shape), session_target_(session_target) {} - -XrtAllocation::~XrtAllocation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } +StatusOr LocalShapedBuffer::DestructureTuple() { + const Shape tuple_shape = shape(); - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); } -} - -/* static */ -StatusOr XrtAllocation::FromLiteral( - const Literal& argument, const string& session_target) { - xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); - *alloc.mutable_value() = argument.ToProto(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto literal_string = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); - TF_RETURN_IF_ERROR(root.status()); - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({literal_string, alloc.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); + DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = Release(); - int64 handle = outputs[0].scalar()(); - return new XrtAllocation(handle, argument.shape(), session_target); -} - -const 
int64 XrtAllocation::handle() const { return handle_; } - -const Shape& XrtAllocation::shape() const { return shape_; } - -StatusOr XrtAllocation::ToLiteral() const { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); - TF_RETURN_IF_ERROR(root.status()); + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + ShapeTree& shape_tree = tuple_buffer.buffers(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. 
+ const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - xla::LiteralProto response; - TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); - return Literal::CreateFromProto(response); -} + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); -XrtAllocationTuple::XrtAllocationTuple(std::vector elements) - : elements_(std::move(elements)) { - for (auto* element : elements_) { - CHECK(element != nullptr); + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_)); } + // Deallocate the root buffer. 
+ se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); } -XrtAllocationTuple::~XrtAllocationTuple() { - for (XrtAllocation* element : elements_) { - if (element != nullptr) { - delete element; - } - } -} +LocalExecutable::LocalExecutable( + std::unique_ptr executable, + xla::DeviceAssignment device_assignment, xla::LocalClient* client) + : executable_(std::move(executable)), + device_assignment_(std::move(device_assignment)), + client_(client) {} -StatusOr XrtAllocationTuple::Release(int i) { - XrtAllocation* element = elements_[i]; - if (element == nullptr) { - return InvalidArgument("Attempted to release already-released element %d.", - i); +std::vector LocalExecutable::DeviceOrdinals() const { + int num_replicas = device_assignment_.replica_count(); + std::vector device_ordinals; + device_ordinals.reserve(num_replicas); + for (int i = 0; i < num_replicas; ++i) { + device_ordinals.push_back(device_assignment_(i, 0)); } - elements_[i] = nullptr; - return element; + return device_ordinals; } -int64 XrtAllocationTuple::size() const { return elements_.size(); } - -CompiledLocalComputation::CompiledLocalComputation( - std::unique_ptr executable) - : executable_(std::move(executable)) {} - -StatusOr CompiledLocalComputation::Execute( +StatusOr LocalExecutable::Execute( absl::Span argument_handles) { - LocalClient* client = GetOrCreateLocalClient(); - StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0); + if (num_replicas() != 1) { + return InvalidArgument( + "Attempted to execute computation with %d replicas using Execute()", + num_replicas()); + } StatusOr result_buffer_status; - if (!device_ordinal_status.ok()) { - result_buffer_status = device_ordinal_status.status(); - } else { - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica 0 mapped to device ordinal for execution: " 
- << device_ordinal; + const int device_ordinal = device_assignment_(0, 0); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(1, /*computation_count=*/1) - .ConsumeValueOrDie(); + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer()); + } - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client_->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); - result_buffer_status = executable_->Run(argument_buffers, options); - } + result_buffer_status = executable_->Run(argument_buffers, options); if (!result_buffer_status.ok()) { return InternalError( @@ -354,34 +256,30 @@ StatusOr CompiledLocalComputation::Execute( "%s.", result_buffer_status.status().ToString()); } - return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(), + client_); } -StatusOr CompiledLocalComputation::ExecutePerReplica( +StatusOr LocalExecutable::ExecutePerReplica( absl::Span> argument_handles) { - LocalClient* client = GetOrCreateLocalClient(); - const int num_replicas = GetReplicaCount(); + const int 
num_devices = client_->device_count(); - if (argument_handles.size() != num_replicas) { + if (argument_handles.size() != num_replicas()) { return InvalidArgument( "Attempted to execute with %d replicas when replica count is %d", - argument_handles.size(), num_replicas); + argument_handles.size(), num_devices); + } + if (argument_handles.size() > num_devices) { + return InvalidArgument( + "Attempted to execute with %d replicas when device count is %d", + argument_handles.size(), num_devices); } - VLOG(1) << "Executing with " << num_replicas << " replicas."; - - // Each replica populates a StatusOr result, but only the output value of - // replica zero is returned. - std::vector> results(num_replicas); - auto execute = [this, client, num_replicas, &argument_handles, - &results](int replica) { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(1) << "Executing with " << num_replicas() << " replicas."; + + std::vector> results(num_replicas()); + auto execute = [this, &argument_handles, &results](int replica) { + const int device_ordinal = device_assignment_(replica, 0); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -391,41 +289,35 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( argument_buffers.push_back(handle->shaped_buffer()); } - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(num_replicas, /*computation_count=*/1) - .ConsumeValueOrDie(); - ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); + options.set_allocator(client_->backend().memory_allocator()); options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - 
options.set_device_assignment(&device_assignment); + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); StatusOr result_buffer_status = executable_->Run(argument_buffers, options); results[replica] = std::move(result_buffer_status); }; - if (num_replicas == 1) { + if (num_replicas() == 1) { // Fast-path if there is only one replica — run the computation on the // current thread. execute(0); } else { // TODO(phawkins): don't recreate the threadpool for each execution. tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - num_replicas - 1); + num_replicas() - 1); - for (int replica = 0; replica < num_replicas - 1; ++replica) { + for (int replica = 0; replica < num_replicas() - 1; ++replica) { pool.Schedule([&execute, replica] { execute(replica); }); } - execute(num_replicas - 1); + execute(num_replicas() - 1); } - std::vector wrapped_results(num_replicas); - for (int replica = 0; replica < num_replicas; ++replica) { + std::vector wrapped_results(num_replicas()); + for (int replica = 0; replica < num_replicas(); ++replica) { auto& statusor = results[replica]; if (!statusor.ok()) { return InternalError( @@ -434,151 +326,43 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( replica, statusor.status().ToString()); } wrapped_results[replica] = - new LocalShapedBuffer(std::move(statusor).ValueOrDie()); + new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_); } return new LocalShapedBufferTuple(std::move(wrapped_results)); } -static StatusOr GetReturnValueShape(const XlaComputation& computation) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation.GetProgramShape()); - return std::move(*program_shape.mutable_result()); -} - -CompiledXrtComputation::CompiledXrtComputation( - const ProgramShape& program_shape, int64 handle, - const string& session_target) - : program_shape_(program_shape), - handle_(handle), - session_target_(session_target) {} - 
-CompiledXrtComputation::~CompiledXrtComputation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({computation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; - } -} - -StatusOr CompiledXrtComputation::Execute( - absl::Span argument_handles) { - const int num_expected_arguments = program_shape().parameters().size(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - std::vector arguments; - arguments.reserve(num_expected_arguments); - for (int i = 0; i < num_expected_arguments; ++i) { - arguments.push_back( - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); - } - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto execution_config = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto execute = tensorflow::ops::XRTExecute(root, computation_handle, - execution_config, arguments); - TF_RETURN_IF_ERROR(root.status()); - - TF_RET_CHECK(argument_handles.size() == arguments.size()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(false); - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - for (int i = 0; i < arguments.size(); ++i) { - inputs.insert({arguments[i], argument_handles[i]->handle()}); - } - inputs.insert({computation_handle, handle()}); - inputs.insert({execution_config, e.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, 
&outputs)); - - int64 output = outputs[0].scalar()(); - return new XrtAllocation(output, program_shape().result(), session_target_); -} - -const ProgramShape& CompiledXrtComputation::program_shape() const { - return program_shape_; -} - -int64 CompiledXrtComputation::handle() const { return handle_; } - -LocalComputation::LocalComputation(XlaComputation computation) +Computation::Computation(XlaComputation computation) : computation_(std::move(computation)) {} -StatusOr LocalComputation::Compile( +StatusOr Computation::Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options) { + const ExecutableBuildOptions* build_options, const LocalClient& client) { std::vector argument_shape_pointers; argument_shape_pointers.reserve(argument_shapes.size()); for (auto& argument_shape : argument_shapes) { argument_shape_pointers.push_back(&argument_shape); } - LocalClient* client = GetOrCreateLocalClient(); ExecutableBuildOptions options; if (build_options != nullptr) { options = *build_options; } TF_ASSIGN_OR_RETURN( auto local_executable, - client->Compile(computation_, argument_shape_pointers, options)); - return new CompiledLocalComputation(std::move(local_executable)); -} - -StatusOr LocalComputation::CompileForXrt( - const std::vector& argument_shapes, const string& session_target) { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto compile = tensorflow::ops::XRTCompile(root, program); - TF_RETURN_IF_ERROR(root.status()); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - ProgramShape shapes; - for (auto& shape : argument_shapes) { - *shapes.add_parameters() = shape; - } - TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape()); - LayoutUtil::SetToDefaultLayout(&shapes); - *config->mutable_program_shape() = shapes.ToProto(); - auto snapshot = computation().Snapshot().ValueOrDie(); - *c.mutable_hlo_snapshot() = *snapshot; 
- - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({program, c.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + client.client()->Compile(computation_, argument_shape_pointers, options)); + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client.client()->backend().computation_placer()->AssignDevices( + options.num_replicas(), /*computation_count=*/1)); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation().GetProgramShape()); - int64 handle = outputs[0].scalar()(); - return new CompiledXrtComputation(program_shape, handle, session_target); + return new LocalExecutable(std::move(local_executable), + std::move(device_assignment), client.client()); } -const XlaComputation& LocalComputation::computation() const { - return computation_; -} +const XlaComputation& Computation::computation() const { return computation_; } -string LocalComputation::GetSerializedProto() const { +string Computation::GetSerializedProto() const { string result; if (!computation_.proto().SerializeToString(&result)) { LOG(ERROR) << "Failed to serialize the HloModuleProto."; @@ -587,123 +371,171 @@ string LocalComputation::GetSerializedProto() const { return result; } -StatusOr LocalComputation::GetReturnValueShape() const { - return swig::GetReturnValueShape(computation_); +StatusOr Computation::GetHloText() const { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation_.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation_.proto(), module_config)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(false); + return hlo_module->ToString(options); +} + +StatusOr Computation::GetHloDotGraph() const { + TF_ASSIGN_OR_RETURN(const 
HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation_.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation_.proto(), module_config)); + hlo_graph_dumper::DotGraphOptions options; + options.debug_options = &hlo_module->config().debug_options(); + return hlo_graph_dumper::HloComputationToDotGraph( + *hlo_module->entry_computation(), options); +} + +StatusOr Computation::GetProgramShape() const { + return computation_.GetProgramShape(); +} + +StatusOr Computation::GetReturnValueShape() const { + TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape()); + return std::move(*shape.mutable_result()); } LocalOp::LocalOp(const XlaOp& op) : op_(op) {} const XlaOp& LocalOp::op() const { return op_; } -LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) +ComputationBuilder::ComputationBuilder(const string& computation_name) : builder_(computation_name) {} -void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { +void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); } -void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } +void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } -StatusOr LocalComputationBuilder::Build() { +StatusOr ComputationBuilder::Build() { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp ComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, const string& name) { return xla::Parameter(&builder_, parameter_number, shape, name); } -StatusOr LocalComputationBuilder::BuildWithRoot( - const LocalOp& root) { +StatusOr 
ComputationBuilder::BuildWithRoot(const LocalOp& root) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { +StatusOr ComputationBuilder::GetShape(const LocalOp& operand) { return builder_.GetShape(operand.op()); } -StatusOr LocalComputationBuilder::GetReturnValueShape() { +StatusOr ComputationBuilder::GetReturnValueShape() { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); return program_shape.result(); } -LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp ComputationBuilder::Infeed(const Shape& shape) { return xla::Infeed(&builder_, shape); } -void LocalComputationBuilder::Outfeed(const LocalOp& operand, - const Shape& shape, - const string& outfeed_config) { +void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, + const string& outfeed_config) { xla::Outfeed(operand.op(), shape, outfeed_config); } -LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { +LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } -LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, absl::Span broadcast_sizes) { +LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) { + return xla::Iota(&builder_, element_type, size); +} + +LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { + return xla::Iota(&builder_, shape, dimension); +} + +LocalOp ComputationBuilder::Broadcast(const LocalOp& operand, + absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); } -LocalOp LocalComputationBuilder::BroadcastInDim( +LocalOp ComputationBuilder::BroadcastInDim( const LocalOp& operand, absl::Span out_dim_sizes, absl::Span broadcast_dimensions) { return 
xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); } -LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, - const LocalOp& padding_value, - const PaddingConfig& padding_config) { +LocalOp ComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, - absl::Span dimensions, - absl::Span new_sizes) { +LocalOp ComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { - return xla::CrossReplicaSum(operand.op()); +LocalOp ComputationBuilder::AllToAll( + const LocalOp& operand, int64 split_dimension, int64 concat_dimension, + int64 split_count, absl::Span replica_groups) { + std::vector rg(replica_groups.size()); + for (int i = 0; i < replica_groups.size(); ++i) { + rg.push_back(replica_groups[i]); + } + return xla::AllToAll(operand.op(), split_dimension, concat_dimension, + split_count, rg); +} + +LocalOp ComputationBuilder::CrossReplicaSum( + const LocalOp& operand, absl::Span replica_groups) { + return xla::CrossReplicaSum(operand.op(), replica_groups); } -LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides) { +LocalOp ComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } -LocalOp 
LocalComputationBuilder::SliceInDim(const LocalOp& operand, - int64 start_index, - int64 limit_index, int64 stride, - int64 dimno) { +LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); } -LocalOp LocalComputationBuilder::DynamicSlice( - const LocalOp& operand, const LocalOp& start_indices, - absl::Span slice_sizes) { +LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand, + const LocalOp& start_indices, + absl::Span slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -LocalOp LocalComputationBuilder::DynamicUpdateSlice( - const LocalOp& operand, const LocalOp& update, - const LocalOp& start_indices) { +LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand, + const LocalOp& update, + const LocalOp& start_indices) { return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, - int64 dimension) { +LocalOp ComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -712,18 +544,18 @@ LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, return xla::ConcatInDim(&builder_, xla_ops, dimension); } -LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, +LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter) { + const LocalOp& init_value, const Computation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), 
window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { +LocalOp ComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -733,22 +565,22 @@ LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { return xla::Tuple(&builder_, xla_ops); } -LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, - int64 index) { +LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { return xla::GetTupleElement(tuple_data.op(), index); } -LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { +LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { return xla::Dot(lhs.op(), rhs.op()); } -LocalOp LocalComputationBuilder::DotGeneral( +LocalOp ComputationBuilder::DotGeneral( const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -LocalOp LocalComputationBuilder::ConvGeneralDilated( +LocalOp ComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -760,18 +592,18 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( feature_group_count); } -LocalOp LocalComputationBuilder::ConvertElementType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::ConvertElementType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::BitcastConvertType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp 
LocalComputationBuilder::Call(const LocalComputation& local_computation, - absl::Span operands) { +LocalOp ComputationBuilder::Call(const Computation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -780,19 +612,34 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } -LocalOp LocalComputationBuilder::Transpose( - const LocalOp& operand, absl::Span permutation) { +LocalOp ComputationBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops, + shape_with_layout, + operand_shapes_with_layout, opaque); +} + +LocalOp ComputationBuilder::Transpose(const LocalOp& operand, + absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map(absl::Span operands, - const LocalComputation& local_computation, - absl::Span dimensions) { +LocalOp ComputationBuilder::Map(absl::Span operands, + const Computation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -803,17 +650,17 @@ LocalOp LocalComputationBuilder::Map(absl::Span operands, dimensions); } -LocalOp LocalComputationBuilder::Reduce( +LocalOp ComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& 
local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } -LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( +LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -825,56 +672,92 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( padding); } -LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, - const LocalOp& sigma, - const Shape& shape) { +LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape) { return xla::RngNormal(mu.op(), sigma.op(), shape); } -LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, - const Shape& shape) { +LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { return xla::RngUniform(a.op(), b.op(), shape); } -LocalOp LocalComputationBuilder::While(const LocalComputation& condition, - const LocalComputation& body, - const LocalOp& init) { +LocalOp ComputationBuilder::While(const Computation& condition, + const Computation& body, + const LocalOp& init) { return xla::While(condition.computation(), body.computation(), init.op()); } -LocalOp LocalComputationBuilder::Conditional( - const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation) { +LocalOp ComputationBuilder::Conditional(const LocalOp& predicate, + const LocalOp& true_operand, + const Computation& true_computation, + const LocalOp& false_operand, + const Computation& false_computation) { return xla::Conditional(predicate.op(), true_operand.op(), 
true_computation.computation(), false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { +StatusOr ComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } -LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { +LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { return xla::Sort(operand.op(), {}, dimension); } -LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, - const LocalOp& values, - int64 dimension) { +LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, int64 dimension) { return xla::Sort(keys.op(), {values.op()}, dimension); } -StatusOr LocalComputationBuilder::BuildConstantSubGraph( +LocalOp ComputationBuilder::Cholesky(const LocalOp& a) { + return xla::Cholesky(a.op()); +} + +LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) { + XlaBuilder* builder = a.op().builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); + return xla::Tuple(builder, {qr.q, qr.r}); + }); +} + +LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b, + bool left_side, bool lower, + bool unit_diagonal, + int transpose_a) { + return xla::TriangularSolve( + a.op(), b.op(), left_side, lower, unit_diagonal, + xla::TriangularSolveOptions::Transpose(transpose_a)); +} + +LocalOp ComputationBuilder::Gather( + const LocalOp& input, const LocalOp& start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes) { + return xla::Gather(input.op(), start_indices.op(), dimension_numbers, + slice_sizes); +} + +LocalOp ComputationBuilder::Scatter( + const LocalOp& input, const LocalOp& scatter_indices, + const LocalOp& updates, const Computation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return 
xla::Scatter(input.op(), scatter_indices.op(), updates.op(), + update_computation.computation(), dimension_numbers); +} + +StatusOr ComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.BuildConstantSubGraph(operand.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -#define _FORWARD(method_name, return_sig, args_sig, args) \ - return_sig LocalComputationBuilder::method_name args_sig { \ - return xla::method_name args; \ +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig ComputationBuilder::method_name args_sig { \ + return xla::method_name args; \ } #define _FORWARD_UNOP(method_name) \ @@ -916,6 +799,7 @@ _FORWARD_BINOP(Atan2) _FORWARD_BINOP(Pow) _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) +_FORWARD_UNOP(Clz) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) _FORWARD_UNOP(Expm1) @@ -961,108 +845,9 @@ void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { delete local_shaped_buffer; } -void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } - -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { - delete computation; -} - -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation) { - delete computation; -} - -void DeleteLocalComputation(LocalComputation* computation) { - delete computation; -} - -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer) { - const Shape tuple_shape = local_shaped_buffer->shape(); - - if (!ShapeUtil::IsTuple(tuple_shape)) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } +void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; } - DeviceMemoryAllocator* allocator = - local_shaped_buffer->shaped_buffer()->memory_allocator(); - ShapedBuffer tuple_buffer = 
local_shaped_buffer->Release(); - - // Extract some metadata we use to construct scoped buffers. - const se::Platform* platform = tuple_buffer.platform(); - int device_ordinal = tuple_buffer.device_ordinal(); - - ShapeTree& shape_tree = tuple_buffer.buffers(); - std::vector results; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - // Create a shaped buffer for this destructured tuple element. - const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); - VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; - ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - - ShapeUtil::ForEachSubshape( - subshape, [&](const Shape& s, const ShapeIndex& index) { - ShapeIndex original(index); - original.push_front(i); - se::DeviceMemoryBase* device_memory = - shape_tree.mutable_element(original); - shaped_buffer.set_buffer(*device_memory, index); - *device_memory = se::DeviceMemoryBase(); - }); - - VLOG(3) << "Completed tuple element: " << i; - results.push_back(new LocalShapedBuffer( - ScopedShapedBuffer(std::move(shaped_buffer), allocator))); - } - // Deallocate the root buffer. 
- se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); - TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); - return new LocalShapedBufferTuple(std::move(results)); -} - -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target) { - const Shape& tuple_shape = allocation->shape(); - - if (!ShapeUtil::IsTuple(tuple_shape)) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); - auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); - TF_RETURN_IF_ERROR(root.status()); - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - std::vector results; - for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - inputs.clear(); - inputs.insert({base_handle, allocation->handle()}); - inputs.insert({shape_index, {i}}); - std::vector outputs; - auto status = session.Run(inputs, {subtuple}, &outputs); - if (!status.ok()) { - // Clean up before returning non-ok status. 
- for (int j = 0; j < results.size(); ++j) { - delete results[j]; - } - return status; - } - const int64 subtuple_handle = outputs[0].scalar()(); - const Shape& subtuple_shape = - ShapeUtil::GetTupleElementShape(tuple_shape, i); - results.push_back( - new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); - } - return new XrtAllocationTuple(std::move(results)); -} +void DeleteComputation(Computation* computation) { delete computation; } } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 149e44570df5c6a3df88bbe2ffa779be47842d82..66b1cce7fb598388af40940ea2ed52ac2f8ee8e1 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -19,10 +19,9 @@ limitations under the License. #include #include +#include + #include "absl/types/span.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -33,37 +32,42 @@ limitations under the License. namespace xla { namespace swig { -// Initializes the number of replicas that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. -Status InitializeReplicaCount(int replica_count); - -// Initializes the platform name that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. 
-Status InitializePlatformName(const string& platform_name); - -// Returns the replica count that is currently set, regardless of whether the -// local XLA service has been instantiated yet or not. -int GetReplicaCount(); - -// Wraps the local client's infeed-transfer function. -// -// The default device ordinal (0) is used. -Status TransferToInfeedLocal(const Literal& literal); - -// Transfers the given literal to the infeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); - -// Transfers a literal of the given shape from the outfeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number); +// Registers a 'fn_capsule' as a CPU custom call target. +// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name +// "xla._CPU_CUSTOM_CALL_TARGET". +Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); + +// Wrapper around an xla::LocalClient. +class LocalClient { + public: + // Initializes a local XLA client for `platform_name`. Returns an error if no + /// such platform exists, or if the platform has no visible devices. + static StatusOr Get(const string& platform_name); + + // Copyable and moveable; the class is just a wrapper around a + // xla::LocalClient pointer for convenient SWIG wrapping. + + // Returns the number of devices known to the XLA client. + int DeviceCount() const; + + // Wraps the local client's infeed-transfer function. + // + // The default device ordinal (0) is used. + Status TransferToInfeed(const Literal& literal, int device_ordinal); + + // Transfers a literal of the given shape from the outfeed of the given + // replica. 
+ StatusOr TransferFromOutfeed(const Shape& shape, int device_ordinal); + + xla::LocalClient* client() const { return client_; } + + private: + LocalClient(xla::LocalClient* client); + + xla::LocalClient* client_; +}; + +class LocalShapedBufferTuple; // Represents a reference to literals that live in a device-allocated buffer via // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a @@ -72,9 +76,9 @@ class LocalShapedBuffer { public: static StatusOr FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number); + const LocalClient& client, int device_ordinal); - LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client); StatusOr ToLiteral() const; const Shape& shape() const; const ScopedShapedBuffer* shaped_buffer() const; @@ -83,8 +87,13 @@ class LocalShapedBuffer { // analogous to std::unique_ptr::release(). ShapedBuffer Release(); + // Destructures a tuple-valued LocalShapedBuffer into its constitutent + // elements in LocalShapedBufferTuple form. + StatusOr DestructureTuple(); + private: ScopedShapedBuffer shaped_buffer_; + xla::LocalClient* client_; }; // Result of a tuple destructuring operation on a LocalShapedBuffer -- this @@ -110,68 +119,20 @@ class LocalShapedBufferTuple { std::vector elements_; }; -// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements -// in LocalShapedBufferTuple form. -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer); - -// Represents a reference to literals that live in a device-allocated buffer via -// XRT. Specifically, wraps an int64 handle produced by running the allocation -// graph, and an XLA shape to track the referent's shape. -class XrtAllocation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which allocation and deallocation - // graphs are run. 
- static StatusOr FromLiteral(const Literal& argument, - const string& session_target); - - XrtAllocation(int64 handle, Shape shape, const string& session_target); - ~XrtAllocation(); - StatusOr ToLiteral() const; - const Shape& shape() const; - const int64 handle() const; - - private: - const int64 handle_; - const Shape shape_; - const string session_target_; -}; - -// Result of a tuple destructuring operation on an XrtAllocation. -class XrtAllocationTuple { - public: - // Note: any XrtAllocation elements that are not Release()'d will be - // deallocated in the destructor. - explicit XrtAllocationTuple(std::vector elements); - - ~XrtAllocationTuple(); - - // Releases the ith element to the caller. Further attempts to release the ith - // element will return an invalid argument error. - StatusOr Release(int i); - - // Returns the number of elements in the destructured tuple. - int64 size() const; - - private: - std::vector elements_; -}; - -// Destructures a tuple-valued XrtAllocation into its constitutent elements -// in XrtAllocationTuple form. -// -// Accepts a `session_target` argument, used in constructing the -// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, -// and passed along in constructing each constituent XrtAllocation. -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target); - // Represents a compiled computation that can be executed given handles to // device-allocated literals. Specifically, wraps an XLA LocalExecutable. -class CompiledLocalComputation { +class LocalExecutable { public: - CompiledLocalComputation(std::unique_ptr executable); + LocalExecutable(std::unique_ptr executable, + xla::DeviceAssignment device_assignment, + xla::LocalClient* client); + + int num_replicas() const { + return executable_->build_options().num_replicas(); + } + + // Returns the device ordinals to which each replica is assigned. 
+ std::vector DeviceOrdinals() const; StatusOr Execute( absl::Span argument_handles); @@ -183,47 +144,22 @@ class CompiledLocalComputation { absl::Span > argument_handles); private: - std::unique_ptr executable_; -}; - -// Represents a compiled computation that can be executed given handles to -// device-allocated literals. Specifically, wraps an XRT computation handle. -class CompiledXrtComputation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the execution graph is run. - CompiledXrtComputation(const ProgramShape& program_shape, int64 handle, - const string& session_target); - ~CompiledXrtComputation(); - - StatusOr Execute( - absl::Span argument_handles); - - const ProgramShape& program_shape() const; - int64 handle() const; - - private: - const ProgramShape program_shape_; - const int64 handle_; - const string session_target_; + const std::unique_ptr executable_; + const xla::DeviceAssignment device_assignment_; + xla::LocalClient* const client_; }; -// Wraps a XlaComputation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a ComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. -class LocalComputation { +class Computation { public: - LocalComputation(XlaComputation computation); + Computation(XlaComputation computation); - StatusOr Compile( + StatusOr Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options); - - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the compilation graph is run. 
- StatusOr CompileForXrt( - const std::vector& argument_shapes, const string& session_target); + const ExecutableBuildOptions* build_options, const LocalClient& client); const XlaComputation& computation() const; @@ -232,6 +168,15 @@ class LocalComputation { // string on failure. string GetSerializedProto() const; + // Returns the computation in human-readable HLO text format. + StatusOr GetHloText() const; + + // Returns the computation in graphviz dot format. + StatusOr GetHloDotGraph() const; + + // Returns the program shape for this computation. + StatusOr GetProgramShape() const; + // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; @@ -239,7 +184,7 @@ class LocalComputation { XlaComputation computation_; }; -// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// Wraps a XlaOp produced by a ComputationBuilder. This class is intended // to be made available to Python via SWIG. class LocalOp { public: @@ -256,20 +201,20 @@ class LocalOp { // Python. // - Set up the underlying builder to use the client library's // LocalClient. -// - Wrap Computations in LocalComputations for Python access. -// - Correspondingly unwrap incoming LocalComputations. -class LocalComputationBuilder { +// - Wrap Computations in Computations for Python access. +// - Correspondingly unwrap incoming Computations. +class ComputationBuilder { public: - LocalComputationBuilder(const string& computation_name); + ComputationBuilder(const string& computation_name); void SetOpMetadata(const OpMetadata& metadata); void ClearOpMetadata(); - // Returns an owned LocalComputation to the caller on success. - StatusOr Build(); + // Returns an owned Computation to the caller on success. + StatusOr Build(); - // Returns an owned LocalComputation to the caller on success with given root. - StatusOr BuildWithRoot(const LocalOp& root); + // Returns an owned Computation to the caller on success with given root. 
+ StatusOr BuildWithRoot(const LocalOp& root); LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); @@ -286,6 +231,10 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); + LocalOp Iota(PrimitiveType element_type, int64 size); + + LocalOp BroadcastedIota(const Shape& shape, int64 dimension); + LocalOp Broadcast(const LocalOp& operand, absl::Span broadcast_sizes); @@ -301,7 +250,12 @@ class LocalComputationBuilder { LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); - LocalOp CrossReplicaSum(const LocalOp& operand); + LocalOp AllToAll(const LocalOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + absl::Span replica_groups); + + LocalOp CrossReplicaSum(const LocalOp& operand, + absl::Span replica_groups); LocalOp Slice(const LocalOp& operand, absl::Span start_indices, absl::Span limit_indices, @@ -319,11 +273,11 @@ class LocalComputationBuilder { LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span > padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter); + const LocalOp& init_value, const Computation& scatter); LocalOp Tuple(absl::Span elements); @@ -349,25 +303,31 @@ class LocalComputationBuilder { LocalOp BitcastConvertType(const LocalOp& operand, PrimitiveType new_element_type); - LocalOp Call(const LocalComputation& local_computation, + LocalOp Call(const Computation& local_computation, absl::Span operands); + LocalOp CustomCall(const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque); + LocalOp Transpose(const LocalOp& operand, absl::Span permutation); LocalOp Rev(const LocalOp& operand, 
absl::Span dimensions); LocalOp Map(absl::Span operands, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce); LocalOp ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -379,13 +339,13 @@ class LocalComputationBuilder { LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - LocalOp While(const LocalComputation& condition, const LocalComputation& body, + LocalOp While(const Computation& condition, const Computation& body, const LocalOp& init); LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, + const Computation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation); + const Computation& false_computation); StatusOr IsConstant(const LocalOp& operand); @@ -394,7 +354,25 @@ class LocalComputationBuilder { LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, int64 dimension); - StatusOr BuildConstantSubGraph(const LocalOp& operand); + LocalOp QR(const LocalOp& a, bool full_matrices); + + LocalOp Cholesky(const LocalOp& a); + + // `transpose_a` is the integer value of a TriangularSolveOptions::Transpose + // enum. We use an integer here so we don't have to teach SWIG about the + // enum. 
+ LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, + bool lower, bool unit_diagonal, int transpose_a); + + LocalOp Gather(const LocalOp& input, const LocalOp& start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes); + + LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, + const LocalOp& updates, const Computation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; @@ -436,6 +414,7 @@ class LocalComputationBuilder { _FORWARD_BINOP(Pow) _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) + _FORWARD_UNOP(Clz) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) _FORWARD_UNOP(Expm1) @@ -483,10 +462,8 @@ class LocalComputationBuilder { // Functions for freeing resources from the Python side. void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); -void DeleteXrtAllocation(XrtAllocation* allocation); -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation); -void DeleteLocalComputation(LocalComputation* computation); +void DeleteLocalExecutable(LocalExecutable* computation); +void DeleteComputation(Computation* computation); } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index d23d693c1e5bde43b52959e4397aa311268411bb..7d7a860baa03e99cc254b7596fb5f9d41acbef20 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -23,17 +23,22 @@ limitations under the License. 
// C++ Python // -------------------------------------+--------------------------------------- // Span <- sequence of int +// vector -> sequence of int // Span <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) // <- object duck-typed as xla_client.Shape +// ProgramShape -> pair of ([arg_shapes], ret_shape) // std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int // Span> <- sequence of int pairs // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // DotDimensionNumbers proto <- corresponding Python proto +// GatherDimensionNumbers proto <- corresponding Python proto +// ScatterDimensionNumbers proto <- corresponding Python proto +// Span <- sequence of ReplicaGroup Python proto // // Arrows indicate whether a conversion only ever occurs in one // direction, or whether it is maintained bidirectionally. @@ -94,7 +99,7 @@ limitations under the License. // wrapped in a Python class (xla_client.Shape) so as not to expose // the raw pair externally. // -// Other SWIG object wrappers (e.g. of LocalComputation) are further +// Other SWIG object wrappers (e.g. of Computation) are further // wrapped by xla_client in order to set up a custom destructor that // triggers memory deallocation on the C++ side. @@ -104,6 +109,7 @@ limitations under the License. %nothread; %include "tensorflow/python/platform/base.i" +%include "tensorflow/compiler/xla/python/xla_data.i" %{ // Must be included first @@ -121,54 +127,6 @@ limitations under the License. 
using namespace xla; using namespace xla::swig; -namespace xla { - -namespace swig { - -bool GetIntAttr(PyObject* o, const char* field, int64* result) { - PyObject* fo = PyObject_GetAttrString(o, field); - if (!fo) { - return false; - } - const int64 value = numpy::PyIntOrPyLongToLong(fo); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(fo); - return false; - } - Py_DECREF(fo); - *result = value; - return true; -} - -// Returns "ok"; true if there is no error, false if there was an error. -bool HandleStringAttribute(PyObject* o, - const char* attr_name, - std::function f) { - if (!PyObject_HasAttrString(o, attr_name)) { - return true; // It's ok for the object to not have the attribute. - } - PyObject* attr = PyObject_GetAttrString(o, attr_name); - if (attr == nullptr) { - return false; // An error occurred getting the attribute. - } - if (attr == Py_None) { - Py_DECREF(attr); - return true; // The attribute is None, which we consider ok. - } - if (!PyString_Check(attr)) { - string message = absl::StrFormat("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr)); - PyErr_SetString(PyExc_TypeError, message.c_str()); - Py_DECREF(attr); - return false; // Type error, not ok. - } - f(PyString_AsString(attr)); - Py_DECREF(attr); - return true; // Handled string attribute, ok! -} - -} -} %} // Required to use PyArray_* functions. 
@@ -176,57 +134,6 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} -// Basic types - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = PyBool_FromLong($1.ConsumeValueOrDie()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) Status { - if (!$1.ok()) { - PyErr_SetString( - PyExc_RuntimeError, $1.ToString().c_str()); - SWIG_fail; - } - Py_INCREF(Py_None); - $result = Py_None; -} - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.resize(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - temps[i] = numpy::PyIntOrPyLongToLong(py_int); - if (temps[i] == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); - SWIG_fail; - } - Py_DECREF(py_int); - Py_DECREF(o); - } - $1 = temps; -} - // Computation builder types %typemap(in) absl::Span( @@ -251,12 +158,12 @@ tensorflow::ImportNumpy(); // Computation and buffer/allocation types -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { - auto* value = $1.ValueOrDie(); + xla::swig::LocalClient value = $1.ValueOrDie(); { - auto* $1 = value; - $typemap(out, xla::swig::CompiledLocalComputation*) + auto $1 = value; + $typemap(out, xla::swig::LocalClient) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -264,12 +171,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::CompiledXrtComputation*) + $typemap(out, xla::swig::LocalExecutable*) } } else { 
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -303,38 +210,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtAllocation*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtAllocationTuple*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::LocalComputation*) + $typemap(out, xla::swig::Computation*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -394,556 +275,6 @@ tensorflow::ImportNumpy(); $1 = temps; } -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - XrtAllocation* xrta; - if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), - SWIG_POINTER_EXCEPTION)) == -1) { - SWIG_fail; - } - temps.push_back(xrta); - Py_DECREF(o); - } - $1 = temps; -} - -// Literal - -%typemap(out) StatusOr { - if ($1.ok()) { - Literal value = $1.ConsumeValueOrDie(); - $result = numpy::PyObjectFromXlaLiteral(*value); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(in) const Literal& (StatusOr literal_status) { - literal_status = numpy::XlaLiteralFromPyObject($input); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - 
SWIG_fail; - } - $1 = &literal_status.ValueOrDie(); -} - -%typemap(out) Literal { - $result = numpy::PyObjectFromXlaLiteral(*$1); -} - -%typemap(out) StatusOr { - if (!$1.ok()) { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } - $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - Py_DECREF(o); - SWIG_fail; - } - temps.push_back(literal_status.ConsumeValueOrDie()); - Py_DECREF(o); - } - $1 = &temps; -} - -// OpMetadata - -%typemap(in) const OpMetadata& (OpMetadata temp) { - StatusOr statusor = numpy::OpMetadataFromPyObject($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -// Shape - -%typemap(out) const Shape& { - $result = numpy::PyShapeInfoFromXlaShape(*$1); -} - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(in) const Shape& (Shape temp) { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -%typemap(in) const absl::optional& ( - absl::optional temp) { - if ($input == Py_None) { - temp = absl::nullopt; - $1 = &temp; - } else { - StatusOr statusor = 
numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; - } -} - -%typemap(out) std::unique_ptr { - $result = numpy::PyShapeInfoFromXlaShape(*$1); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - $1 = &temps; -} - -%typemap(in) const std::vector >& ( - std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (o == Py_None) { - temps.push_back(absl::nullopt); - } else { - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - } - $1 = &temps; -} - -// PrimitiveType - -%typemap(in) PrimitiveType { - PyObject* py_int = numpy::PyNumberToPyInt($input); - if (!py_int) { - PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - SWIG_fail; - } - const long value = numpy::PyIntOrPyLongToLong(py_int); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - SWIG_fail; - } - if (!PrimitiveType_IsValid(value)) { - PyErr_SetString( - PyExc_TypeError, "Argument not valid for PrimitiveType enum"); - Py_DECREF(py_int); - 
SWIG_fail; - } - $1 = static_cast(value); -} - -// Span> - -%typemap(in) absl::Span > - (std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (!o) { - SWIG_fail; - } - PyObject* first = PyTuple_GetItem(o, 0); - if (!first) { - Py_DECREF(o); - SWIG_fail; - } - PyObject* first_pyint = numpy::PyNumberToPyInt(first); - if (!first_pyint) { - PyErr_SetString( - PyExc_TypeError, - "First pair item cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - PyObject* second = PyTuple_GetItem(o, 1); - if (!second) { - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - PyObject* second_pyint = numpy::PyNumberToPyInt(second); - if (!second_pyint) { - PyErr_SetString( - PyExc_TypeError, - "Second pair item cannot be converted to int"); - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); - if (first_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); - if (second_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - temps.push_back(std::make_pair(first_value, second_value)); - Py_DECREF(o); - } - $1 = temps; -} - -// DotDimensionNumbers - -%typemap(in) const DotDimensionNumbers& - (DotDimensionNumbers dimension_numbers) { - int length; - - /* lhs_contracting_dimensions */ - PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( - $input, "lhs_contracting_dimensions"); - if (!lhs_contracting_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(lhs_contracting_dimensions); - if (length == -1) { - Py_DECREF(lhs_contracting_dimensions); 
- SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); - if (!item) { - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - dimension_numbers.add_lhs_contracting_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(lhs_contracting_dimensions); - - /* rhs_contracting_dimensions */ - PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( - $input, "rhs_contracting_dimensions"); - if (!lhs_contracting_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(rhs_contracting_dimensions); - if (length == -1) { - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); - if (!item) { - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - dimension_numbers.add_rhs_contracting_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(rhs_contracting_dimensions); - - /* lhs_batch_dimensions */ - PyObject* lhs_batch_dimensions = PyObject_GetAttrString( - $input, "lhs_batch_dimensions"); - if (!lhs_batch_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(lhs_batch_dimensions); - if (length == -1) { - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); - if (!item) { - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - 
dimension_numbers.add_lhs_batch_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(lhs_batch_dimensions); - - /* rhs_batch_dimensions */ - PyObject* rhs_batch_dimensions = PyObject_GetAttrString( - $input, "rhs_batch_dimensions"); - if (!rhs_batch_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(rhs_batch_dimensions); - if (length == -1) { - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); - if (!item) { - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - dimension_numbers.add_rhs_batch_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(rhs_batch_dimensions); - - $1 = &dimension_numbers; -} - -// PaddingConfig - -%typemap(in) const PaddingConfig& - (PaddingConfig padding_config) { - PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); - if (!dimensions) { - SWIG_fail; - } - - int length = PySequence_Size(dimensions); - if (length == -1) { - Py_DECREF(dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(dimensions, i); - if (!item) { - Py_DECREF(dimensions); - SWIG_fail; - } - int64 edge_padding_low, edge_padding_high, interior_padding; - if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) - || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) - || !GetIntAttr(item, "interior_padding", &interior_padding)) { - Py_DECREF(item); - Py_DECREF(dimensions); - SWIG_fail; - } - Py_DECREF(item); - - PaddingConfig::PaddingConfigDimension* dimension = - padding_config.add_dimensions(); - dimension->set_edge_padding_low(edge_padding_low); - dimension->set_edge_padding_high(edge_padding_high); - dimension->set_interior_padding(interior_padding); - } - Py_DECREF(dimensions); - - $1 = 
&padding_config; -} - -// ConvolutionDimensionNumbers - -%typemap(in) const ConvolutionDimensionNumbers& - (ConvolutionDimensionNumbers dimension_numbers) { - int64 value; - - if (!GetIntAttr($input, "input_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_batch_dimension(value); - - if (!GetIntAttr($input, "input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_feature_dimension(value); - - if (!GetIntAttr($input, "output_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_batch_dimension(value); - - if (!GetIntAttr($input, "output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_input_feature_dimension(value); - - PyObject* o; - int length; - - o = PyObject_GetAttrString($input, "input_spatial_dimensions"); - if (!o) { - SWIG_fail; - } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); - SWIG_fail; - } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_input_spatial_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(o); - - o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); - if (!o) { - SWIG_fail; - } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); - SWIG_fail; - } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = 
numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_kernel_spatial_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(o); - - o = PyObject_GetAttrString($input, "output_spatial_dimensions"); - if (!o) { - SWIG_fail; - } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); - SWIG_fail; - } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_output_spatial_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(o); - - $1 = &dimension_numbers; -} - // ExecutableBuildOptions %typemap(in) const ExecutableBuildOptions* @@ -1000,6 +331,12 @@ tensorflow::ImportNumpy(); } Py_DECREF(o); + int64 num_replicas; + if (!GetIntAttr($input, "num_replicas", &num_replicas)) { + SWIG_fail; + } + build_options.set_num_replicas(num_replicas); + $1 = &build_options; } } @@ -1007,150 +344,151 @@ tensorflow::ImportNumpy(); %ignoreall %unignore xla; %unignore xla::swig; -%unignore xla::swig::InitializeReplicaCount; -%unignore xla::swig::InitializePlatformName; -%unignore xla::swig::GetReplicaCount; -%unignore xla::swig::TransferToInfeedLocal; -%unignore xla::swig::TransferToInfeedLocalReplica; -%unignore xla::swig::TransferFromOutfeedLocalReplica; +%unignore xla::swig::RegisterCpuCustomCallTarget; +%unignore xla::swig::LocalClient; +%unignore xla::swig::LocalClient::Get; +%unignore xla::swig::LocalClient::DeviceCount; +%unignore xla::swig::LocalClient::TransferToInfeed; +%unignore xla::swig::LocalClient::TransferFromOutfeed; %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; %unignore 
xla::swig::LocalShapedBuffer::shape; +%unignore xla::swig::LocalShapedBuffer::DestructureTuple; %unignore xla::swig::LocalShapedBufferTuple; %unignore xla::swig::LocalShapedBufferTuple::Release; %unignore xla::swig::LocalShapedBufferTuple::size; -%unignore xla::swig::XrtAllocation; -%unignore xla::swig::XrtAllocation::FromLiteral; -%unignore xla::swig::XrtAllocation::ToLiteral; -%unignore xla::swig::XrtAllocation::shape; -%unignore xla::swig::XrtAllocationTuple; -%unignore xla::swig::XrtAllocationTuple::Release; -%unignore xla::swig::XrtAllocationTuple::size; -%unignore xla::swig::CompiledLocalComputation; -%unignore xla::swig::CompiledLocalComputation::Execute; -%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; -%unignore xla::swig::CompiledXrtComputation; -%unignore xla::swig::CompiledXrtComputation::Execute; -%unignore xla::swig::LocalComputation; -%unignore xla::swig::LocalComputation::Compile; -%unignore xla::swig::LocalComputation::CompileForXrt; -%unignore xla::swig::LocalComputation::GetReturnValueShape; -%unignore xla::swig::LocalComputation::GetSerializedProto; +%unignore xla::swig::LocalExecutable; +%unignore xla::swig::LocalExecutable::DeviceOrdinals; +%unignore xla::swig::LocalExecutable::Execute; +%unignore xla::swig::LocalExecutable::ExecutePerReplica; +%unignore xla::swig::Computation; +%unignore xla::swig::Computation::Compile; +%unignore xla::swig::Computation::GetProgramShape; +%unignore xla::swig::Computation::GetReturnValueShape; +%unignore xla::swig::Computation::GetSerializedProto; +%unignore xla::swig::Computation::GetHloText; +%unignore xla::swig::Computation::GetHloDotGraph; %unignore xla::swig::LocalOp; -%unignore xla::swig::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::Build; -%unignore xla::swig::LocalComputationBuilder::BuildWithRoot; -%unignore xla::swig::LocalComputationBuilder::SetOpMetadata; -%unignore 
xla::swig::LocalComputationBuilder::ClearOpMetadata; -%unignore xla::swig::LocalComputationBuilder::Parameter; -%unignore xla::swig::LocalComputationBuilder::GetShape; -%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; -%unignore xla::swig::LocalComputationBuilder::Infeed; -%unignore xla::swig::LocalComputationBuilder::Outfeed; -%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; -%unignore xla::swig::LocalComputationBuilder::ConstantR0; -%unignore xla::swig::LocalComputationBuilder::Broadcast; -%unignore xla::swig::LocalComputationBuilder::BroadcastInDim; -%unignore xla::swig::LocalComputationBuilder::Pad; -%unignore xla::swig::LocalComputationBuilder::Reshape; -%unignore xla::swig::LocalComputationBuilder::Collapse; -%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; -%unignore xla::swig::LocalComputationBuilder::Slice; -%unignore xla::swig::LocalComputationBuilder::SliceInDim; -%unignore xla::swig::LocalComputationBuilder::DynamicSlice; -%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; -%unignore xla::swig::LocalComputationBuilder::ConcatInDim; -%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::Select; -%unignore xla::swig::LocalComputationBuilder::Tuple; -%unignore xla::swig::LocalComputationBuilder::GetTupleElement; -%unignore xla::swig::LocalComputationBuilder::ConvertElementType; -%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; -%unignore xla::swig::LocalComputationBuilder::Call; -%unignore xla::swig::LocalComputationBuilder::Transpose; -%unignore xla::swig::LocalComputationBuilder::Rev; -%unignore xla::swig::LocalComputationBuilder::Clamp; -%unignore xla::swig::LocalComputationBuilder::Map; -%unignore xla::swig::LocalComputationBuilder::Reduce; -%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::RngNormal; -%unignore 
xla::swig::LocalComputationBuilder::RngUniform; -%unignore xla::swig::LocalComputationBuilder::RngBernoulli; -%unignore xla::swig::LocalComputationBuilder::While; -%unignore xla::swig::LocalComputationBuilder::Conditional; -%unignore xla::swig::LocalComputationBuilder::IsConstant; -%unignore xla::swig::LocalComputationBuilder::Eq; -%unignore xla::swig::LocalComputationBuilder::Ne; -%unignore xla::swig::LocalComputationBuilder::Ge; -%unignore xla::swig::LocalComputationBuilder::Gt; -%unignore xla::swig::LocalComputationBuilder::Lt; -%unignore xla::swig::LocalComputationBuilder::Le; -%unignore xla::swig::LocalComputationBuilder::Dot; -%unignore xla::swig::LocalComputationBuilder::DotGeneral; -%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; -%unignore xla::swig::LocalComputationBuilder::Add; -%unignore xla::swig::LocalComputationBuilder::Sub; -%unignore xla::swig::LocalComputationBuilder::Mul; -%unignore xla::swig::LocalComputationBuilder::Div; -%unignore xla::swig::LocalComputationBuilder::Rem; -%unignore xla::swig::LocalComputationBuilder::Max; -%unignore xla::swig::LocalComputationBuilder::Min; -%unignore xla::swig::LocalComputationBuilder::And; -%unignore xla::swig::LocalComputationBuilder::Or; -%unignore xla::swig::LocalComputationBuilder::Xor; -%unignore xla::swig::LocalComputationBuilder::ShiftLeft; -%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; -%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; -%unignore xla::swig::LocalComputationBuilder::Not; -%unignore xla::swig::LocalComputationBuilder::Abs; -%unignore xla::swig::LocalComputationBuilder::Exp; -%unignore xla::swig::LocalComputationBuilder::Expm1; -%unignore xla::swig::LocalComputationBuilder::Floor; -%unignore xla::swig::LocalComputationBuilder::Ceil; -%unignore xla::swig::LocalComputationBuilder::Round; -%unignore xla::swig::LocalComputationBuilder::Log; -%unignore xla::swig::LocalComputationBuilder::Log1p; -%unignore 
xla::swig::LocalComputationBuilder::Sign; -%unignore xla::swig::LocalComputationBuilder::Cos; -%unignore xla::swig::LocalComputationBuilder::Sin; -%unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::Atan2; -%unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::Pow; -%unignore xla::swig::LocalComputationBuilder::Neg; -%unignore xla::swig::LocalComputationBuilder::Sort; -%unignore xla::swig::LocalComputationBuilder::SortKeyVal; -%unignore xla::swig::LocalComputationBuilder::Sqrt; -%unignore xla::swig::LocalComputationBuilder::Rsqrt; -%unignore xla::swig::LocalComputationBuilder::Square; -%unignore xla::swig::LocalComputationBuilder::Reciprocal; -%unignore xla::swig::LocalComputationBuilder::Erfc; -%unignore xla::swig::LocalComputationBuilder::Erf; -%unignore xla::swig::LocalComputationBuilder::ErfInv; -%unignore xla::swig::LocalComputationBuilder::Lgamma; -%unignore xla::swig::LocalComputationBuilder::Digamma; -%unignore xla::swig::LocalComputationBuilder::Acos; -%unignore xla::swig::LocalComputationBuilder::Asin; -%unignore xla::swig::LocalComputationBuilder::Atan; -%unignore xla::swig::LocalComputationBuilder::Tan; -%unignore xla::swig::LocalComputationBuilder::Acosh; -%unignore xla::swig::LocalComputationBuilder::Asinh; -%unignore xla::swig::LocalComputationBuilder::Atanh; -%unignore xla::swig::LocalComputationBuilder::Cosh; -%unignore xla::swig::LocalComputationBuilder::Sinh; -%unignore xla::swig::LocalComputationBuilder::Real; -%unignore xla::swig::LocalComputationBuilder::Imag; -%unignore xla::swig::LocalComputationBuilder::Conj; -%unignore xla::swig::LocalComputationBuilder::Complex; -%unignore xla::swig::DeleteLocalComputation; -%unignore xla::swig::DestructureLocalShapedBufferTuple; -%unignore xla::swig::DestructureXrtAllocationTuple; +%unignore xla::swig::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::ComputationBuilder; +%unignore 
xla::swig::ComputationBuilder::Build; +%unignore xla::swig::ComputationBuilder::BuildWithRoot; +%unignore xla::swig::ComputationBuilder::SetOpMetadata; +%unignore xla::swig::ComputationBuilder::ClearOpMetadata; +%unignore xla::swig::ComputationBuilder::Parameter; +%unignore xla::swig::ComputationBuilder::GetShape; +%unignore xla::swig::ComputationBuilder::GetReturnValueShape; +%unignore xla::swig::ComputationBuilder::Infeed; +%unignore xla::swig::ComputationBuilder::Outfeed; +%unignore xla::swig::ComputationBuilder::ConstantLiteral; +%unignore xla::swig::ComputationBuilder::ConstantR0; +%unignore xla::swig::ComputationBuilder::Iota; +%unignore xla::swig::ComputationBuilder::BroadcastedIota; +%unignore xla::swig::ComputationBuilder::Broadcast; +%unignore xla::swig::ComputationBuilder::BroadcastInDim; +%unignore xla::swig::ComputationBuilder::Pad; +%unignore xla::swig::ComputationBuilder::Reshape; +%unignore xla::swig::ComputationBuilder::Collapse; +%unignore xla::swig::ComputationBuilder::AllToAll; +%unignore xla::swig::ComputationBuilder::CrossReplicaSum; +%unignore xla::swig::ComputationBuilder::Slice; +%unignore xla::swig::ComputationBuilder::SliceInDim; +%unignore xla::swig::ComputationBuilder::DynamicSlice; +%unignore xla::swig::ComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::ComputationBuilder::ConcatInDim; +%unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::Select; +%unignore xla::swig::ComputationBuilder::Tuple; +%unignore xla::swig::ComputationBuilder::GetTupleElement; +%unignore xla::swig::ComputationBuilder::ConvertElementType; +%unignore xla::swig::ComputationBuilder::BitcastConvertType; +%unignore xla::swig::ComputationBuilder::Call; +%unignore xla::swig::ComputationBuilder::Transpose; +%unignore xla::swig::ComputationBuilder::Rev; +%unignore xla::swig::ComputationBuilder::Clamp; +%unignore xla::swig::ComputationBuilder::Map; +%unignore 
xla::swig::ComputationBuilder::Reduce; +%unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::RngNormal; +%unignore xla::swig::ComputationBuilder::RngUniform; +%unignore xla::swig::ComputationBuilder::RngBernoulli; +%unignore xla::swig::ComputationBuilder::While; +%unignore xla::swig::ComputationBuilder::Conditional; +%unignore xla::swig::ComputationBuilder::IsConstant; +%unignore xla::swig::ComputationBuilder::Eq; +%unignore xla::swig::ComputationBuilder::Ne; +%unignore xla::swig::ComputationBuilder::Ge; +%unignore xla::swig::ComputationBuilder::Gt; +%unignore xla::swig::ComputationBuilder::Lt; +%unignore xla::swig::ComputationBuilder::Le; +%unignore xla::swig::ComputationBuilder::Dot; +%unignore xla::swig::ComputationBuilder::DotGeneral; +%unignore xla::swig::ComputationBuilder::ConvGeneralDilated; +%unignore xla::swig::ComputationBuilder::Add; +%unignore xla::swig::ComputationBuilder::Sub; +%unignore xla::swig::ComputationBuilder::Mul; +%unignore xla::swig::ComputationBuilder::Div; +%unignore xla::swig::ComputationBuilder::Rem; +%unignore xla::swig::ComputationBuilder::Max; +%unignore xla::swig::ComputationBuilder::Min; +%unignore xla::swig::ComputationBuilder::And; +%unignore xla::swig::ComputationBuilder::Or; +%unignore xla::swig::ComputationBuilder::Xor; +%unignore xla::swig::ComputationBuilder::ShiftLeft; +%unignore xla::swig::ComputationBuilder::ShiftRightArithmetic; +%unignore xla::swig::ComputationBuilder::ShiftRightLogical; +%unignore xla::swig::ComputationBuilder::Not; +%unignore xla::swig::ComputationBuilder::Clz; +%unignore xla::swig::ComputationBuilder::Abs; +%unignore xla::swig::ComputationBuilder::Exp; +%unignore xla::swig::ComputationBuilder::Expm1; +%unignore xla::swig::ComputationBuilder::Floor; +%unignore xla::swig::ComputationBuilder::Ceil; +%unignore xla::swig::ComputationBuilder::Round; +%unignore xla::swig::ComputationBuilder::Log; +%unignore xla::swig::ComputationBuilder::Log1p; 
+%unignore xla::swig::ComputationBuilder::Sign; +%unignore xla::swig::ComputationBuilder::Cos; +%unignore xla::swig::ComputationBuilder::Sin; +%unignore xla::swig::ComputationBuilder::Tanh; +%unignore xla::swig::ComputationBuilder::Atan2; +%unignore xla::swig::ComputationBuilder::IsFinite; +%unignore xla::swig::ComputationBuilder::Pow; +%unignore xla::swig::ComputationBuilder::Neg; +%unignore xla::swig::ComputationBuilder::Sort; +%unignore xla::swig::ComputationBuilder::SortKeyVal; +%unignore xla::swig::ComputationBuilder::Sqrt; +%unignore xla::swig::ComputationBuilder::Rsqrt; +%unignore xla::swig::ComputationBuilder::Square; +%unignore xla::swig::ComputationBuilder::Reciprocal; +%unignore xla::swig::ComputationBuilder::Erfc; +%unignore xla::swig::ComputationBuilder::Erf; +%unignore xla::swig::ComputationBuilder::ErfInv; +%unignore xla::swig::ComputationBuilder::Lgamma; +%unignore xla::swig::ComputationBuilder::Digamma; +%unignore xla::swig::ComputationBuilder::Acos; +%unignore xla::swig::ComputationBuilder::Asin; +%unignore xla::swig::ComputationBuilder::Atan; +%unignore xla::swig::ComputationBuilder::Tan; +%unignore xla::swig::ComputationBuilder::Acosh; +%unignore xla::swig::ComputationBuilder::Asinh; +%unignore xla::swig::ComputationBuilder::Atanh; +%unignore xla::swig::ComputationBuilder::Cosh; +%unignore xla::swig::ComputationBuilder::Sinh; +%unignore xla::swig::ComputationBuilder::Real; +%unignore xla::swig::ComputationBuilder::Imag; +%unignore xla::swig::ComputationBuilder::Conj; +%unignore xla::swig::ComputationBuilder::Complex; +%unignore xla::swig::ComputationBuilder::Cholesky; +%unignore xla::swig::ComputationBuilder::QR; +%unignore xla::swig::ComputationBuilder::TriangularSolve; +%unignore xla::swig::ComputationBuilder::CustomCall; +%unignore xla::swig::ComputationBuilder::Gather; +%unignore xla::swig::ComputationBuilder::Scatter; +%unignore xla::swig::DeleteComputation; %unignore xla::swig::DeleteLocalShapedBuffer; -%unignore 
xla::swig::DeleteXrtAllocation; -%unignore xla::swig::DeleteCompiledLocalComputation; -%unignore xla::swig::DeleteCompiledXrtComputation; +%unignore xla::swig::DeleteLocalExecutable; %thread; %include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index b0aa024c7474cf8e6934432b2f364be464714999..74f45b7cdcfd7d7b10a5832be37ac1fb34057743 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -26,6 +26,10 @@ namespace swig { namespace numpy { +Safe_PyObjectPtr make_safe(PyObject* object) { + return Safe_PyObjectPtr(object); +} + int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { switch (primitive_type) { case PRED: @@ -54,6 +58,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT64; case C64: return NPY_COMPLEX64; + case C128: + return NPY_COMPLEX128; case TUPLE: return NPY_OBJECT; default: @@ -89,6 +95,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F64; case NPY_COMPLEX64: return C64; + case NPY_COMPLEX128: + return C128; case NPY_OBJECT: return TUPLE; default: @@ -111,6 +119,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT32: case NPY_FLOAT64: case NPY_COMPLEX64: + case NPY_COMPLEX128: case NPY_OBJECT: return true; default: @@ -118,28 +127,42 @@ bool NumpyTypeIsValid(int np_type) { } } -PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { +Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) { int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); - PyObject* dimensions; - if (ShapeUtil::IsTuple(shape)) { + Safe_PyObjectPtr dimensions; + if (shape.IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(shape); - dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); + dimensions = 
make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape))); for (int i = 0; i < num_elements; ++i) { PyTuple_SET_ITEM( - dimensions, i, - PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); + dimensions.get(), i, + PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)) + .release()); } } else { - int rank = ShapeUtil::Rank(shape); - dimensions = PyTuple_New(rank); + int rank = shape.rank(); + dimensions = make_safe(PyTuple_New(rank)); for (int i = 0; i < rank; ++i) { - PyTuple_SET_ITEM(dimensions, i, + PyTuple_SET_ITEM(dimensions.get(), i, LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); } } - return PyTuple_Pack(2, np_dtype, dimensions); + return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release())); +} + +Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( + const ProgramShape& shape) { + Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size())); + for (int i = 0; i < shape.parameters_size(); ++i) { + PyTuple_SET_ITEM(arg_shapes.get(), i, + PyShapeInfoFromXlaShape(shape.parameters(i)).release()); + } + + Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result()); + return make_safe( + PyTuple_Pack(2, arg_shapes.release(), result_shape.release())); } // Precondition: o->ob_type == &PyArrayDescr_Type @@ -344,26 +367,30 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { - if (ShapeUtil::IsTuple(literal.shape())) { +StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal) { + if (literal.shape().IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); - PyObject* tuple = PyTuple_New(num_elements); + std::vector elems(num_elements); + for (int i = 0; i < num_elements; i++) { + TF_ASSIGN_OR_RETURN(elems[i], + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); + } + Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements)); for (int i = 0; i < num_elements; i++) { - 
PyTuple_SET_ITEM(tuple, i, - PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); + PyTuple_SET_ITEM(tuple.get(), i, elems[i].release()); } return tuple; } else { - int rank = ShapeUtil::Rank(literal.shape()); + int rank = literal.shape().rank(); std::vector dimensions(rank); // NOLINT - PyArray requires a long* for (int i = 0; i < rank; i++) { dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); } int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); - PyObject* array = - PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); - CopyLiteralToNumpyArray(np_type, literal, - reinterpret_cast(array)); + Safe_PyObjectPtr array = make_safe( + PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0)); + TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray( + np_type, literal, reinterpret_cast(array.get()))); return array; } } @@ -403,6 +430,12 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_BOOL: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_INT8: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; case NPY_INT32: CopyNumpyArrayToLiteral(py_array, literal); break; @@ -412,6 +445,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_UINT8: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_UINT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; case NPY_UINT32: CopyNumpyArrayToLiteral(py_array, literal); break; @@ -430,6 +466,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_COMPLEX64: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_COMPLEX128: + CopyNumpyArrayToLiteral(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -437,12 +476,18 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void 
CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array) { +Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, + PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_INT8: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT16: + CopyLiteralToNumpyArray(literal, py_array); + break; case NPY_INT32: CopyLiteralToNumpyArray(literal, py_array); break; @@ -452,6 +497,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_UINT8: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_UINT16: + CopyLiteralToNumpyArray(literal, py_array); + break; case NPY_UINT32: CopyLiteralToNumpyArray(literal, py_array); break; @@ -470,9 +518,14 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_COMPLEX64: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_COMPLEX128: + CopyLiteralToNumpyArray(literal, py_array); + break; default: - LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; + return InvalidArgument( + "No XLA literal container for Numpy type number: %d", np_type); } + return Status::OK(); } PyObject* LongToPyIntOrPyLong(long x) { // NOLINT @@ -514,6 +567,92 @@ PyObject* PyNumberToPyInt(PyObject* o) { } // namespace numpy +bool GetIntAttr(PyObject* o, const char* field, int64* result) { + PyObject* fo = PyObject_GetAttrString(o, field); + if (!fo) { + return false; + } + const int64 value = numpy::PyIntOrPyLongToLong(fo); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(fo); + return false; + } + Py_DECREF(fo); + *result = value; + return true; +} + +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. 
+ } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. + } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } +#if PY_MAJOR_VERSION < 3 + if (!PyString_Check(attr)) { + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyString_AsString(attr)); +#else + if (!PyBytes_Check(attr)) { + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyBytes_AsString(attr)); +#endif + + Py_DECREF(attr); + return true; // Handled string attribute, ok! +} + +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field) { + PyObject* seq = PyObject_GetAttrString(o, attr_name); + if (!seq) { + return false; + } + + int length = PySequence_Size(seq); + if (length == -1) { + Py_DECREF(seq); + return false; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(seq, i); + if (!item) { + Py_DECREF(seq); + return false; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(seq); + return false; + } + *field->Add() = dimension; + Py_DECREF(item); + } + Py_DECREF(seq); + return true; +} + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 40ff2d9ad214cc4dcad42234fa296834cbc92882..eff8cda334f00050605febad66a61aa1c518c500 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ 
b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -36,6 +36,16 @@ namespace swig { namespace numpy { +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +// Safe container for an owned PyObject. On destruction, the reference count of +// the contained object will be decremented. +using Safe_PyObjectPtr = std::unique_ptr; + +Safe_PyObjectPtr make_safe(PyObject* object); + // Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy // dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and // vice versa. @@ -54,7 +64,13 @@ bool NumpyTypeIsValid(int np_type); // providing the array dimensions. // // The return value is a new reference. -PyObject* PyShapeInfoFromXlaShape(const Shape& shape); +Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape); + +// Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple +// of argument shapes and result_shape is the result shape. Each shape is as +// described in PyShapeInfoFromXlaShape's comment. +Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( + const ProgramShape& shape); // Converts a Python object with a method interface mathing that of // xla_client.Shape into an XLA Shape object. @@ -74,7 +90,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. -PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); +StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. 
@@ -90,8 +106,8 @@ StatusOr XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array); +Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, + PyArrayObject* py_array); template void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { @@ -120,6 +136,18 @@ PyObject* PyNumberToPyInt(PyObject* o); } // namespace numpy +// Miscellaneous swig helpers that don't have a better home. + +bool GetIntAttr(PyObject* o, const char* field, int64* result); + +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, const char* attr_name, + std::function f); + +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field); + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds new file mode 100644 index 0000000000000000000000000000000000000000..ef77ed3d95850fdfc7145e6fe1df4833d20bb7df --- /dev/null +++ b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds @@ -0,0 +1,2 @@ +_PyInit__pywrap_xla +_init_pywrap_xla diff --git a/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..d31cfce7be7b6accf05ef77f3485904099965afc --- /dev/null +++ b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds @@ -0,0 +1,6 @@ +xla { + global: + PyInit_*; + local: + *; +}; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c91a2aaf56dfe2127168628c78e0c4b868a28055..9019a979a61c6ebb62adaa5503560c604e2b30f8 100644 --- 
a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An in-process, local XLA client in Python, supporting AOT compilation.""" +"""An XLA client in Python, supporting AOT compilation.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import enum # pylint: disable=g-bad-import-order import inspect @@ -33,13 +34,32 @@ from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api from tensorflow.compiler.xla.service import hlo_pb2 +# Import the XRT backend, if available. +try: + # pylint: disable=g-import-not-at-top + from tensorflow.compiler.xla.python import pywrap_xrt as xrt_api +except ImportError: + xrt_api = None + # Most functions are snake_case for consistency with other modules, whereas -# method names of ComputationBuilder and LocalComputation are CamelCase for +# method names of ComputationBuilder and Computation are CamelCase for # consistency with XLA. # pylint: disable=invalid-name +# Version of the XLA Python client. +# +# JAX packages the XLA python plugin as a binary pip module (jaxlib) that is +# packaged separately from the Python code that consumes it (jax). +# +# We occasionally need to make backwards-incompatible changes to jaxlib, in +# which case we need to be able to detect when incompatible versions are +# installed. 
+def version(): + return (0, 1, 8) + + _OP_METADATA_FIELDS = [ 'op_type', 'op_name', @@ -49,13 +69,163 @@ _OP_METADATA_FIELDS = [ OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) +@six.add_metaclass(abc.ABCMeta) +class Backend(object): + """Abstract base class for XLA backends.""" + + @abc.abstractmethod + def device_count(self): + """Returns the number of devices known to the backend.""" + + @abc.abstractmethod + def buffer_from_pyval(self, pyval, device=0): + """Allocates a fresh buffer and populates it with `pyval`.""" + + @abc.abstractmethod + def delete_buffer(self, c_buffer): + """Deletes buffer `c_buffer`.""" + + @abc.abstractmethod + def destructure_tuple(self, c_buffer): + """Destructures a tuple buffer into a sequence of buffers.""" + + @abc.abstractmethod + def compile(self, computation, argument_shapes, result_shape, + compile_options): + """Compiles a computation. Returns an executable.""" + + @abc.abstractmethod + def delete_executable(self, executable): + """Deletes an executable.""" + + @abc.abstractmethod + def execute(self, executable, args): + """Runs an executable without replication.""" + + @abc.abstractmethod + def execute_replicated(self, executable, per_replica_args): + """Runs an executable in a replicated manner.""" + + +def _maybe_encode_string(s): + if six.PY3: + return s.encode('utf-8') + else: + return s + + +class XlaLocalBackend(Backend): + """XLA backend implemented using the in-process xla::LocalClient API.""" + + def __init__(self, platform=None): + platform = platform or _get_default_platform_name() + self.client = c_api.LocalClient.Get(_maybe_encode_string(platform)) + self._delete_buffer = c_api.DeleteLocalShapedBuffer + self._delete_executable = c_api.DeleteLocalExecutable + + def device_count(self): + return self.client.DeviceCount() + + def buffer_from_pyval(self, pyval, device=0): + return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device) + + def delete_buffer(self, c_buffer): + 
self._delete_buffer(c_buffer) + + def destructure_tuple(self, c_buffer): + result = c_buffer.DestructureTuple() + return [result.Release(i) for i in xrange(result.size())] + + def compile(self, c_computation, argument_shapes, result_shape, + compile_options): + return c_computation.Compile(argument_shapes, compile_options, self.client) + + def delete_executable(self, executable): + self._delete_executable(executable) + + def execute(self, executable, args): + return executable.Execute(args) + + def execute_replicated(self, executable, per_replica_args): + output_buffer_tup = executable.ExecutePerReplica(per_replica_args) + size = output_buffer_tup.size() + return [output_buffer_tup.Release(i) for i in xrange(size)] + + +class XrtBackend(Backend): + """XLA backend implemented using XRT.""" + + def __init__(self, target): + self.target = target + self._delete_buffer = xrt_api.DeleteXrtAllocation + self._delete_executable = xrt_api.DeleteXrtExecutable + + def device_count(self): + return 1 # Multidevice execution not implemented. 
+ + def buffer_from_pyval(self, pyval, device=0): + if device != 0: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + return xrt_api.XrtAllocation.FromLiteral(pyval, + _maybe_encode_string(self.target)) + + def delete_buffer(self, c_buffer): + self._delete_buffer(c_buffer) + + def destructure_tuple(self, c_buffer): + result = xrt_api.DestructureXrtAllocationTuple( + c_buffer, _maybe_encode_string(self.target)) + return [result.Release(i) for i in xrange(result.size())] + + def compile(self, c_computation, argument_shapes, result_shape, + compile_options): + return xrt_api.XrtExecutable.CompileForXrt( + c_computation.GetSerializedProto(), argument_shapes, result_shape, + _maybe_encode_string(self.target)) + + def delete_executable(self, executable): + self._delete_executable(executable) + + def execute(self, executable, args): + return executable.Execute(args) + + def execute_replicated(self, executable, per_replica_args): + if len(per_replica_args) != 1: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + return [executable.Execute(per_replica_args[0])] + + +_default_platform_name = 'Host' +_default_backend = None + + +def _get_default_platform_name(): + return _default_platform_name + + +def _get_default_local_backend(): + global _default_backend + global _default_platform_name + if _default_backend is None: + _default_backend = XlaLocalBackend(_default_platform_name) + return _default_backend + + class BackendType(enum.Enum): XLA_LOCAL = 1 XRT = 2 -BackendSpec = collections.namedtuple('Backend', ('backend_type', 'target')) -XLA_LOCAL_BACKEND = BackendSpec(BackendType.XLA_LOCAL, 'local') +def BackendSpec(backend, target): + """Compatibility wrapper to support older clients. 
Do not use in new code.""" + if backend == BackendType.XLA_LOCAL: + return _get_default_local_backend() + elif backend == BackendType.XRT: + return XrtBackend(target) + else: + raise ValueError('Unknown backend {}'.format(backend)) def OpMetadataToProto(pyobj): @@ -78,13 +248,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) -def _maybe_encode_string(s): - if six.PY3: - return s.encode('utf-8') - else: - return s - - class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -122,6 +285,7 @@ def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, _UNARY_OPS = [ 'Not', + 'Clz', 'Abs', 'Exp', 'Expm1', @@ -199,6 +363,7 @@ XLA_ELEMENT_TYPE_TO_DTYPE = { xla_data_pb2.F32: np.dtype('float32'), xla_data_pb2.F64: np.dtype('float64'), xla_data_pb2.C64: np.dtype('complex64'), + xla_data_pb2.C128: np.dtype('complex128'), xla_data_pb2.TUPLE: np.dtype(np.object), } @@ -222,33 +387,18 @@ class LocalBuffer(object): means the referent is in device memory. 
""" - def __init__(self, c_buffer, backend, replica): + def __init__(self, c_buffer, backend, device): self.c_buffer = c_buffer self._backend = backend - self._replica = replica - if backend.backend_type == BackendType.XRT: - self._delete = c_api.DeleteXrtAllocation - else: - self._delete = c_api.DeleteLocalShapedBuffer + self._device = device @staticmethod - def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, device=0, backend=None): """Allocate and copy to XLA the given python value.""" + backend = backend or _get_default_local_backend() pyval = require_numpy_array_layout(pyval) - num_replicas = get_replica_count() - if not 0 <= replica < num_replicas: - raise ValueError( - 'Attempt to place buffer on replica {} when the replica count is {}' - .format(replica, num_replicas)) - if backend.backend_type == BackendType.XRT: - if replica != 0: - raise NotImplementedError( - 'Multi-replica execution is not yet supported via the XRT backend.') - cbuf = c_api.XrtAllocation.FromLiteral( - pyval, _maybe_encode_string(backend.target)) - else: - cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica) - return LocalBuffer(cbuf, backend, replica) + cbuf = backend.buffer_from_pyval(pyval, device) + return LocalBuffer(cbuf, backend, device) def to_py(self): return self.c_buffer.ToLiteral() @@ -256,29 +406,22 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) - def replica(self): - return self._replica + def device(self): + return self._device def delete(self): if self.c_buffer is not None: - self._delete(self.c_buffer) + self._backend.delete_buffer(self.c_buffer) self.c_buffer = None def destructure(self): """Assuming a tuple buffer, unpack it into constituent tuple elements.""" assert self.c_buffer is not None - if self._backend.backend_type == BackendType.XRT: - result = c_api.DestructureXrtAllocationTuple( - self.c_buffer, _maybe_encode_string(self._backend.target)) - else: - result = 
c_api.DestructureLocalShapedBufferTuple(self.c_buffer) + result = self._backend.destructure_tuple(self.c_buffer) self.delete() - size = result.size() - destructured = tuple( - LocalBuffer( - result.Release(i), replica=self._replica, backend=self._backend) - for i in xrange(size)) - return destructured + return tuple( + LocalBuffer(sub_buffer, device=self._device, backend=self._backend) + for sub_buffer in result) def is_deleted(self): return self.c_buffer is None @@ -415,7 +558,7 @@ class Shape(object): assert mtm is None, self if mtm is not None: assert self.rank() == len(mtm), self - assert sorted(mtm) == range(len(mtm)), self + assert sorted(mtm) == list(range(len(mtm))), self def update_minor_to_major(self, minor_to_major): if not self.is_array(): @@ -427,6 +570,34 @@ class Shape(object): updated._check_minor_to_major() # pylint: disable=protected-access return updated + def with_major_to_minor_layout_if_absent(self): + """Returns a copy of a shape with missing layouts set to major-to-minor.""" + + def f(a): + if a.minor_to_major(): + return None + return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1))) + + return self.map_leaves(f) + + def serialize(self, proto): + """Serializes 'shape' into proto.""" + if self.is_tuple(): + proto.element_type = xla_data_pb2.TUPLE + for shape in self.tuple_shapes(): + shape.serialize(proto.tuple_shapes.add()) + else: + proto.element_type = dtype_to_etype(self.element_type()) + proto.dimensions.extend(self.dimensions()) + proto.is_dynamic_dimension.extend([False for _ in self.dimensions()]) + if self.minor_to_major(): + proto.layout.format = xla_data_pb2.DENSE + proto.layout.minor_to_major.extend(self.minor_to_major()) + + +ProgramShape = collections.namedtuple('ProgramShape', + ('parameter_shapes', 'result_shape')) + def _wrap_shape(shape_info): dtype, dims = shape_info @@ -438,6 +609,12 @@ def _wrap_shape(shape_info): return Shape.array_shape(dtype, dims) +def _wrap_program_shape(shape_info): + arg_shapes, 
result_shape = shape_info + return ProgramShape([_wrap_shape(arg) for arg in arg_shapes], + _wrap_shape(result_shape)) + + def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -458,9 +635,10 @@ class CompileOptions(object): self.dump_unoptimized_hlo_proto_to = None self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False + self.num_replicas = get_replica_count() -def transfer_to_infeed(value, replica_number=None): +def transfer_to_infeed(value, device_ordinal=0): """Transfers the given value into the XLA infeed queue. XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with @@ -470,64 +648,50 @@ def transfer_to_infeed(value, replica_number=None): Args: value: the value that the caller would like to enqueue into the XLA infeed queue - replica_number: the replica number to infeed the value to -- if not - provided, then the default replica (trivially replica 0) is used. + device_ordinal: the device to infeed the value to. Each device has a + distinct infeed queue. """ - if replica_number is None: - c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) - else: - c_api.TransferToInfeedLocalReplica( - require_numpy_array_layout(value), replica_number) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + backend.client.TransferToInfeed( + require_numpy_array_layout(value), device_ordinal) -def transfer_from_outfeed(shape, replica_number=None): - """Transfers a literal of the given shape from replica_number's outfeed. +def transfer_from_outfeed(shape, device_ordinal=0): + """Transfers a literal of the given shape from `device_ordinal`'s outfeed. Args: shape: The shape of the value to transfer from outfeed. - replica_number: The replica number ordinal to transfer the outfeed value - from. (Each replica has a distinct outfeed queue.) + device_ordinal: The device ordinal to transfer the outfeed value from. 
Each + device has a distinct outfeed queue. Returns: The literal value that is produced from the outfeed queue. """ - return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + return backend.client.TransferFromOutfeed(shape, device_ordinal) -class LocalComputation(object): - """Python wrapper for a local XLA Computation. +class Computation(object): + """Python wrapper for an XLA Computation. - A LocalComputation can be executed if it is compiled. Otherwise, it - can still be used as a Computation where required by the - ComputationBuilder methods. + A Computation can be compiled to form an Executable, or used as a + subcomputation in ComputationBuilder methods. """ - def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND): + def __init__(self, c_computation, backend=None): self._c_computation = c_computation + # The backend argument is deprecated. Pass a backend to Compile() instead. self._backend = backend - self._is_compiled = is_compiled - - # Ensure a reference to C-based destructor for use in __del__. - if is_compiled: - if backend.backend_type == BackendType.XRT: - assert isinstance(c_computation, c_api.CompiledXrtComputation) - self._delete = c_api.DeleteCompiledXrtComputation - else: - assert isinstance(c_computation, c_api.CompiledLocalComputation) - self._delete = c_api.DeleteCompiledLocalComputation - else: - assert isinstance(c_computation, c_api.LocalComputation) - self._delete = c_api.DeleteLocalComputation + self._delete_computation = c_api.DeleteComputation @property def computation(self): - if self._is_compiled: - raise ValueError( - 'Attempt to read the XLA computation of a compiled LocalComputation.') return self._c_computation def GetProto(self): - """Get the HloModuleProto proto object in this local computation. + """Get the HloModuleProto proto object in this computation. 
Returns: An HloModuleProto proto object that has the whole-graph information. @@ -536,30 +700,41 @@ class LocalComputation(object): proto = hlo_pb2.HloModuleProto.FromString(serialized) return proto - def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): - """Compiles an un-compiled local computation. + def GetHloText(self): + """Get the textual HLO representation of this computation. + + Returns: + A string containing the textual HLO. + """ + return self.computation.GetHloText() + + def GetHloDotGraph(self): + """Get a Graphviz Dot representation of this computation. + + Returns: + A string containing the graphviz dot graph. + """ + return self.computation.GetHloDotGraph() - Local computations are the result of a "LocalComputationBuild'ing" process - -- they start in uncompiled form, and via a call to Compile() turn into a - compiled local computation. + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None, + backend=None): + """Compiles a computation. - Raises: - ValueError: if this is already a compiled local computation. + Computations are the result of a "ComputationBuild'ing" process. Arguments: argument_shapes: parameter shapes -- they are first laid out by layout_fn if layout_fn is provided. Otherwise, the default layout for those shapes will be used. - compile_options: options to use for compilation, includes an optional - laid out result shape for the computation. + compile_options: options to use for compilation, includes an optional laid + out result shape for the computation. layout_fn: lambda that is used to lay out the argument/result shapes. + backend: a `Backend` for which an executable should be generated. Returns: - A newly *compiled* local computation instance. + An Executable instance. 
""" - if self._is_compiled: - raise ValueError('Attempt to compile a compiled local XLA computation.') - + backend = backend or self._backend or _get_default_local_backend() result_shape = _wrap_shape(self.computation.GetReturnValueShape()) if layout_fn: @@ -572,32 +747,52 @@ class LocalComputation(object): compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape - if self._backend.backend_type == BackendType.XRT: - c = self.computation.CompileForXrt( - argument_shapes, _maybe_encode_string(self._backend.target)) - else: - c = self.computation.Compile(argument_shapes, compile_options) - return LocalComputation(c, is_compiled=True, backend=self._backend) + c = backend.compile(self.computation, argument_shapes, result_shape, + compile_options) + return Executable(c, backend=backend) def CompileWithExampleArguments(self, arguments=(), compile_options=None, - layout_fn=None): + layout_fn=None, + backend=None): return self.Compile( argument_shapes=[Shape.from_pyval(arg) for arg in arguments], compile_options=compile_options, - layout_fn=layout_fn) + layout_fn=layout_fn, + backend=backend) + + def GetProgramShape(self): + return _wrap_program_shape(self._c_computation.GetProgramShape()) def GetReturnValueShape(self): return _wrap_shape(self._c_computation.GetReturnValueShape()) + def __del__(self): + if self._c_computation: + self._delete_computation(self._c_computation) + + +class Executable(object): + """Python wrapper for an XLA Executable.""" + + def __init__(self, c_executable, backend=None): + self._c_executable = c_executable + self._device_ordinals = c_executable.DeviceOrdinals() + self._backend = backend + + def DeviceOrdinals(self): + """Returns a list containing the device ordinals for each replica.""" + return self._device_ordinals + def Execute(self, arguments=(), check_for_deleted_args=True): """Execute on one replica with LocalBuffer arguments and return value.""" if check_for_deleted_args and any(arg.is_deleted() 
for arg in arguments): raise ValueError('Executing with deleted local buffer argument') raw_args = [arg.c_buffer for arg in arguments] - output_buffer = self._c_computation.Execute(raw_args) - return LocalBuffer(output_buffer, backend=self._backend, replica=0) + output_buffer = self._backend.execute(self._c_executable, raw_args) + return LocalBuffer( + output_buffer, backend=self._backend, device=self._device_ordinals[0]) def ExecutePerReplica(self, arguments=None): """Execute on many replicas with LocalBuffer arguments and return value. @@ -607,14 +802,12 @@ class LocalComputation(object): sequence comprises the arguments for execution on the i'th replica. Returns: - A list of the computation's outputs on each replica, as a LocalBuffer. If + A list of the computation's outputs for each replica, as a LocalBuffer. If a shallow sequence of arguments was passed in for `arguments`, then the sole, zero'th replica's output is returned instead, as a LocalBuffer. """ - if not self._is_compiled: - raise ValueError('Cannot execute an uncompiled local XLA computation.') if arguments is None: - arguments = ((),) * get_replica_count() + arguments = ((),) * len(self._device_ordinals) else: arguments = [list(replica_args) for replica_args in arguments] @@ -623,37 +816,35 @@ class LocalComputation(object): for arg in replica_args: if arg.is_deleted(): raise ValueError('Executing with deleted local buffer argument') - if arg.replica() != replica: + if arg.device() != self._device_ordinals[replica]: raise ValueError( - 'Executing on replica {} with argument from replica {}'.format( - replica, arg.replica())) + 'Executing on device {} with argument from device {}'.format( + self._device_ordinals[replica], arg.device())) # Pull out argument buffer handles + # pylint: disable=g-complex-comprehension stripped_args = [ [arg.c_buffer for arg in replica_args] for replica_args in arguments ] # Execute - if self._backend.backend_type == BackendType.XRT: - if len(stripped_args) > 1: - raise 
NotImplementedError( - 'Multi-replica execution is not yet supported via the XRT backend.') - output_buffers = [self._c_computation.Execute(stripped_args[0])] - else: - output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) - size = output_buffer_tup.size() - output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] + output_buffers = self._backend.execute_replicated(self._c_executable, + stripped_args) # Wrap output handles in LocalBuffer instances return tuple( - LocalBuffer(output_buffer, backend=self._backend, replica=replica) + LocalBuffer( + output_buffer, + backend=self._backend, + device=self._device_ordinals[replica]) for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): """Execute on one replica with Python values as arguments and output.""" def put(arg): - return LocalBuffer.from_pyval(arg, backend=self._backend) + return LocalBuffer.from_pyval( + arg, device=self._device_ordinals[0], backend=self._backend) arguments = [put(arg) for arg in arguments] return self.Execute(arguments).to_py() @@ -661,24 +852,33 @@ class LocalComputation(object): def ExecuteWithPythonValuesPerReplica(self, arguments): """Execute on many replicas with Python values as arguments and output.""" - def put(arg, replica): - return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + def put(arg, device): + return LocalBuffer.from_pyval(arg, device, backend=self._backend) - arguments = [[put(arg, replica) - for arg in replica_args] - for replica, replica_args in enumerate(arguments)] + # pylint: disable=g-complex-comprehension + arguments = [[ + put(arg, self._device_ordinals[replica]) for arg in replica_args + ] for replica, replica_args in enumerate(arguments)] return [out.to_py() for out in self.ExecutePerReplica(arguments)] def __del__(self): - self._delete(self._c_computation) + # Python may have freed c_api first. 
+ if c_api and self._c_executable: + self._backend.delete_executable(self._c_executable) + + +def _make_replica_group_proto(replica_group): + replica_group_proto = xla_data_pb2.ReplicaGroup() + replica_group_proto.replica_ids.extend(replica_group) + return replica_group_proto class ComputationBuilder(object): """XLA computation builder. Enqueues XLA ops in sequence and in order to build a - LocalComputation, which in turn can be compiled into a - CompiledLocalComputation, which in turn can be locally executed. + Computation, which in turn can be compiled into a + LocalExecutable, which in turn can be locally executed. """ # The methods of this class map 1-to-1 onto the XLA C++ @@ -689,16 +889,23 @@ class ComputationBuilder(object): # pylint: disable=g-doc-args def __init__(self, name): - self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._client = c_api.ComputationBuilder(name.encode('utf8')) self._parameter_numbering = itertools.count() - def Build(self, root=None, backend=XLA_LOCAL_BACKEND): + def Build(self, root=None, backend=None): + """Builds a `Computation` from the contents of the builder. + + Args: + root: if not None, the operator containing the return value of the + computation. + backend: deprecated. Pass a `backend` to `Computation.Compile` instead. + Returns: + A `Computation`. 
+ """ if root is not None: - return LocalComputation( - self._client.BuildWithRoot(root), is_compiled=False, backend=backend) + return Computation(self._client.BuildWithRoot(root), backend=backend) else: - return LocalComputation( - self._client.Build(), is_compiled=False, backend=backend) + return Computation(self._client.Build(), backend=backend) def SetOpMetadata(self, op_metadata): """Set metadata for operations that are about to be enqueued.""" @@ -831,6 +1038,33 @@ class ComputationBuilder(object): return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) + def Iota(self, dtype, size): + """Enqueues an iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + size: integer, the number of elements in the array. + + Returns: + A LocalOp representing the added iota constant. + """ + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + return self._client.Iota(element_type, size) + + def BroadcastedIota(self, dtype, shape, dimension): + """Enqueues a broadcasted iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + shape: tuple of integers, the expected output shape (dimensions). + dimension: positive integer, dimension along which to increment values. + + Returns: + A LocalOp representing the added broadcasted iota constant. + """ + xla_shape = Shape.array_shape(dtype, shape) + return self._client.BroadcastedIota(xla_shape, dimension) + def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. @@ -936,16 +1170,60 @@ class ComputationBuilder(object): dimensions = tuple(range(ndim)) return self._client.Reshape(operand, dimensions, new_sizes) - def CrossReplicaSum(self, operand): + def AllToAll(self, + operand, + split_dimension, + concat_dimension, + replica_groups=None): + """AllToAll op. 
+ + Args: + operand: LocalOp representing the input array + split_dimension: the dimension along which the operand is split + concat_dimension: the dimension along which the split blocks are + concatenated + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the all-to-all is performed. If not supplied or None (the + default), all replicas belong to the same group. + + Returns: + A LocalOp that represents the all-to-all concatenation. + """ + if replica_groups is None: + replica_groups_protos = [] # special value for XLA API + else: + replica_groups = list(replica_groups) + replica_groups_protos = [ + _make_replica_group_proto(group) for group in replica_groups] + if not replica_groups: + split_count = get_replica_count() + else: + split_count = len(replica_groups[0]) + if not all(split_count == len(g) for g in replica_groups): + raise ValueError('Replica groups must be equally sized') + return self._client.AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups_protos) + + def CrossReplicaSum(self, operand, replica_groups=None): """CrossReplicaSum op. Args: operand: the operand to sum across replica instances. + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the cross-replica sum is performed. If not supplied or None + (the default), all replicas belong to the same group. Returns: - A LocalOp that has the sum of the value among all replicas. + A LocalOp that represents on each replica the sum of its group's values. 
""" - return self._client.CrossReplicaSum(operand) + if replica_groups is None: + replica_groups = [] # special value for XLA API + else: + replica_groups = [ + _make_replica_group_proto(group) for group in replica_groups] + return self._client.CrossReplicaSum(operand, replica_groups) def Collapse(self, operand, dimensions): """Collapse op.""" @@ -1102,6 +1380,31 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.computation, operands) + def CustomCall(self, + call_target_name, + operands, + shape_with_layout, + operand_shapes_with_layout, + opaque=None): + """Enqueues a custom call operation onto the computation. + + Args: + call_target_name: the name of the function to call. + operands: an iterable of LocalOp. The number and types of operands must + match the arity of `operand_shapes_with_layout`. + shape_with_layout: the shape of the operator's output, with layout. + operand_shapes_with_layout: the shapes of `operands`, including the + expected layouts. + opaque: an opaque string passed to the backend. + + Returns: + A LocalOp representing the added custom call op. + """ + opaque = opaque or b'' + return self._client.CustomCall(call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque) + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. @@ -1254,7 +1557,7 @@ class ComputationBuilder(object): Args: operand: a LocalOp to test. - Returns: a LocalComputation that is rooted on the given `operand` which is a + Returns: a Computation that is rooted on the given `operand` which is a compile-time constant. 
""" return self._client.BuildConstantSubGraph(operand) @@ -1411,13 +1714,51 @@ class ComputationBuilder(object): """Enqueues a key-value sort operation onto the computation.""" return self._client.SortKeyVal(keys, values, dimension) + def Cholesky(self, a): + """Enqueues a Cholesky decomposition onto the computation.""" + return self._client.Cholesky(a) + + def QR(self, a, full_matrices=True): + """Enqueues a QR decomposition onto the computation.""" + return self._client.QR(a, full_matrices) + + def TriangularSolve(self, + a, + b, + left_side=False, + lower=False, + transpose_a=False, + conjugate_a=False, + unit_diagonal=False): + """Enqueues a triangular-solve operation onto the computation.""" + if not transpose_a: + transpose = 1 + if conjugate_a: + a = self.Conj(a) + else: + transpose = 3 if conjugate_a else 2 + return self._client.TriangularSolve(a, b, left_side, lower, unit_diagonal, + transpose) + + def Gather(self, a, start_indices, dimension_numbers, slice_sizes): + """Enqueues a Gather operation onto the computation.""" + return self._client.Gather(a, start_indices, dimension_numbers, + slice_sizes) + + def Scatter(self, a, scatter_indices, updates, update_computation, + dimension_numbers): + """Enqueues a Scatter operation onto the computation.""" + return self._client.Scatter( + a, scatter_indices, updates, update_computation.computation, + dimension_numbers,) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. Set up methods, corresponding to unary and binary XLA operations, whose calls are forwarded in a boilerplate manner to the underlying - LocalComputationBuilder C-extension API. + ComputationBuilder C-extension API. 
""" def forward_to_local_builder_with_handles(target_method, is_binop=False): @@ -1437,13 +1778,13 @@ def _forward_methods_to_local_builder(): for method_name in _UNARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name)) + getattr(c_api.ComputationBuilder, method_name)) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) for method_name in _BINARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + getattr(c_api.ComputationBuilder, method_name), is_binop=True) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) @@ -1451,8 +1792,14 @@ def _forward_methods_to_local_builder(): _forward_methods_to_local_builder() +_default_replica_count = 1 + + def initialize_replica_count(replica_count): - """Initializes the desired replica count to use on XLA service init. + """Initializes the default replica count to use. + + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. Args: replica_count: number of replicas that are desired for set up during XLA @@ -1461,29 +1808,40 @@ def initialize_replica_count(replica_count): Raises: A runtime exception if the XLA service has already been initialized. """ - c_api.InitializeReplicaCount(replica_count) + global _default_replica_count + _default_replica_count = replica_count + + +def get_replica_count(): + """Returns the default replica count. + + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. + """ + return _default_replica_count def initialize_platform_name(platform_name): - """Initializes the desired platform name to use on XLA service init. + """Initializes the default platform name to use for XLA. Args: platform_name: string name of platform. - - Raises: - A runtime exception if the XLA service has already been initialized. 
""" - platform_name = _maybe_encode_string(platform_name) - c_api.InitializePlatformName(platform_name) + global _default_platform_name + _default_platform_name = platform_name + # Make sure the platform is valid by trying to instantiate it. + _get_default_local_backend() -def get_replica_count(): - """Returns the current replica count used for the XLA service. - Note: this will return a value whether the XLA service has been initialized - yet or not. +def register_cpu_custom_call_target(name, fn): + """Registers a CPU custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. """ - return c_api.GetReplicaCount() + c_api.RegisterCpuCustomCallTarget(name, fn) def GetPaddingConfigFromTriples(triples): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 21b5c93b615ec429a5da0b4ffe89e8f75f59ef1b..51ef7d7f3a17f341e955f48615b05a886813430b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -18,16 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import itertools import threading import numpy as np +from tensorflow.compiler.xla.python import custom_call_for_test from tensorflow.compiler.xla.python import xla_client import unittest -class LocalComputationTest(unittest.TestCase): +class ComputationTest(unittest.TestCase): """Base class for running an XLA Computation through the local client.""" def _NewComputation(self, name=None): @@ -51,9 +53,11 @@ class LocalComputationTest(unittest.TestCase): def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) - def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): - self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, 
- expected) + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) def NumpyArrayF32(*args, **kwargs): @@ -81,9 +85,35 @@ def NumpyArrayBool(*args, **kwargs): return np.array(*args, dtype=np.bool, **kwargs) -class ComputationsWithConstantsTest(LocalComputationTest): +class ComputationPrinting(unittest.TestCase): + + def ExampleComputation(self): + builder = xla_client.ComputationBuilder("acomputation") + p0 = builder.ParameterFromNumpy(np.float32(0)) + p1 = builder.ParameterFromNumpy(np.zeros((4,), np.float32)) + builder.Mul(p0, p1) + return builder.Build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.GetHloText() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.GetHloDotGraph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationsWithConstantsTest(ComputationTest): """Tests focusing on Constant ops.""" + def testConstantScalarSumS8(self): + c = self._NewComputation() + root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) + self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) + self._ExecuteAndCompareExact(c, expected=np.int8(3)) + def testConstantScalarSumF32(self): c = self._NewComputation() root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) @@ -143,6 +173,17 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + def testIota(self): + c = self._NewComputation() + c.Iota(np.float32, 10) + self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) + + def testBroadcastedIota(self): + c = 
self._NewComputation() + c.BroadcastedIota(np.int64, (2, 3), 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64) + self._ExecuteAndCompareExact(c, expected=expected) + def testBooleanAnd(self): c = self._NewComputation() c.And( @@ -268,8 +309,22 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF64([100, -100, 200, -200]))) self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + def testCustomCall(self): + c = self._NewComputation() + for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): + xla_client.register_cpu_custom_call_target(name, fn) + c.CustomCall( + b"test_subtract_f32", + operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), + shape_with_layout=xla_client.Shape.array_shape(np.float32, (), ()), + operand_shapes_with_layout=( + xla_client.Shape.array_shape(np.float32, (), ()), + xla_client.Shape.array_shape(np.float32, (), ()), + )) + self._ExecuteAndCompareClose(c, expected=0.75) + -class ParametersTest(LocalComputationTest): +class ParametersTest(ComputationTest): """Tests focusing on Parameter ops and argument-passing.""" def setUp(self): @@ -349,7 +404,7 @@ class ParametersTest(LocalComputationTest): expected=[-4.3, 1.3, -6.3, 3.3]) -class LocalBufferTest(LocalComputationTest): +class LocalBufferTest(ComputationTest): """Tests focusing on execution with LocalBuffers.""" def _Execute(self, c, arguments): @@ -447,7 +502,7 @@ class LocalBufferTest(LocalComputationTest): self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) -class SingleOpTest(LocalComputationTest): +class SingleOpTest(ComputationTest): """Tests for single ops. 
The goal here is smoke testing - to exercise the most basic functionality of @@ -524,6 +579,18 @@ class SingleOpTest(LocalComputationTest): for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) + # TODO(b/123523486): re-enable when shape check is resolved + def DISABLED_testAllToAllOneReplica(self): + samples = [ + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples[:1]: + c = self._NewComputation() + c.AllToAll(c.Constant(lhs), 0, 0) + self._ExecuteAndCompareExact(c, expected=lhs) + def testCrossReplicaSumOneReplica(self): samples = [ NumpyArrayF32(42.0), @@ -536,6 +603,18 @@ class SingleOpTest(LocalComputationTest): c.CrossReplicaSum(c.Constant(lhs)) self._ExecuteAndCompareExact(c, expected=lhs) + def testCrossReplicaSumOneReplicaWithSingletonGroup(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + c.CrossReplicaSum(c.Constant(lhs), [[0]]) + self._ExecuteAndCompareExact(c, expected=lhs) + def testDotMatrixVectorF32(self): c = self._NewComputation() lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) @@ -698,6 +777,12 @@ class SingleOpTest(LocalComputationTest): c.Not(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=~arr) + def testCountLeadingZeros(self): + c = self._NewComputation() + arr = NumpyArrayS32([0x7FFF, 0x12345678]) + c.Clz(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=[17, 3]) + def testExp(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -1057,6 +1142,38 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], + dtype=np.float32) + c = self._NewComputation() 
+ c.Cholesky(c.Constant(np.dot(l, l.T))) + self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) + + def testQR(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + c.QR(c.Constant(a), full_matrices=True) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False, + lower=True, transpose_a=True) + self._ExecuteAndCompareClose(c, expected=np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], dtype=np.float32), rtol=1e-4) + def testIsConstant(self): c = self._NewComputation() a = c.ConstantS32Scalar(3) @@ -1068,8 +1185,23 @@ class SingleOpTest(LocalComputationTest): self.assertFalse(c.IsConstant(non_const_expr)) # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) + def testGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) + dnums = xla_client.xla_data_pb2.GatherDimensionNumbers() + dnums.offset_dims.append(1) + dnums.offset_dims.append(2) + dnums.start_index_map.append(0) + dnums.start_index_map.append(1) + dnums.index_vector_dim = 2 + c = self._NewComputation() + c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) + g = self._Execute(c, ()) + expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) + np.testing.assert_allclose(g, expected, rtol=1e-4) -class EmbeddedComputationsTest(LocalComputationTest): + +class EmbeddedComputationsTest(ComputationTest): """Tests for XLA graphs with 
embedded computations (such as maps).""" def _CreateConstantS32Computation(self): @@ -1125,6 +1257,14 @@ class EmbeddedComputationsTest(LocalComputationTest): c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) return c.Build() + def _CreateBinaryAddS32Computation(self): + """Computation (s32, s32) -> s32 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayS32(0)), + c.ParameterFromNumpy(NumpyArrayS32(0))) + return c.Build() + def _CreateBinaryAddF32Computation(self): """Computation (f32, f32) -> f32 that adds its two parameters.""" c = self._NewComputation("add_param0_by_param1") @@ -1507,8 +1647,25 @@ class EmbeddedComputationsTest(LocalComputationTest): execution.join() self.assertEqual(want, got) + def testScatter(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_client.xla_data_pb2.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + c = self._NewComputation() + c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), + self._CreateBinaryAddS32Computation(), dnums) + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) + self._ExecuteAndCompareClose(c, expected=expected) + -class ErrorTest(LocalComputationTest): +class ErrorTest(ComputationTest): def setUp(self): self.f32_scalar_2 = NumpyArrayF32(2.0) @@ -1525,7 +1682,7 @@ class ErrorTest(LocalComputationTest): lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) -class ComputationRootTest(LocalComputationTest): +class ComputationRootTest(ComputationTest): """Tests related to setting the root of the computation.""" def testComputationRootDifferentFromLastOp(self): diff --git 
a/tensorflow/compiler/xla/python/xla_data.i b/tensorflow/compiler/xla/python/xla_data.i new file mode 100644 index 0000000000000000000000000000000000000000..974f314af24f61c0015a8d51c16dff1bfc84c7cc --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_data.i @@ -0,0 +1,654 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// SWIG typemaps and declarations for building, compiling, and +// executing XLA computations, wrapping most of what is declared in +// xla_data.h. 
+// +// The typemaps below implement/assert the following correspondences +// (with elaborations below): +// +// C++ Python +// -------------------------------------+--------------------------------------- +// Span <- sequence of int +// vector -> sequence of int +// Span <- sequence of LocalOp +// Literal <-> (nested tuple of) numpy ndarray +// std::vector <- sequence of (nested tuple of) ndarray +// Shape -> pair holding (dtype, dimensions) +// <- object duck-typed as xla_client.Shape +// ProgramShape -> pair of ([arg_shapes], ret_shape) +// std::vector <- sequence of xla_client.Shape objects +// PrimitiveType <- int +// Span> <- sequence of int pairs +// PaddingConfig proto <- corresponding Python proto +// ConvolutionDimensionNumbers proto <- corresponding Python proto +// DotDimensionNumbers proto <- corresponding Python proto +// GatherDimensionNumbers proto <- corresponding Python proto +// ScatterDimensionNumbers proto <- corresponding Python proto +// Span <- sequence of ReplicaGroup Python proto +// +// Arrows indicate whether a conversion only ever occurs in one +// direction, or whether it is maintained bidirectionally. +// +// The Python objects corresponding to C++ Literals have the type: +// +// T = ndarray | (T, ...) +// +// where a terminal numpy ndarray translates to a Literal with a +// non-tuple Shape, an XLA primitive element type corresponding to the +// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates +// to a tuple-shaped Literal whose tuple components are translated +// recursively. For example, if x is a numpy ndarray in Python, with +// shape (2, 3) and dtype of dtype('float32'), then x translates to a +// Literal with rank 2, dimension 2 and 3, and XLA primitive type +// F32. Meanwhile, +// +// (x, (x, x), (x,)), +// +// translates to a tuple-shaped XLA Literal, whose component subshapes +// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. 
+// +// Shapes output by C++ become Python objects with the type: +// +// T = (dtype, S) +// S = DIMENSIONS | TUPLE_SHAPES +// DIMENSIONS = (int, ...) +// TUPLE_SHAPES = (T, ...) +// +// In the pair described by the T rule, the terminal dtype determines +// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is +// dtype('O'), numpy's object dtype, the structure represents a tuple +// shape and the expansion of the non-terminal S is +// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type +// and S expands into DIMENSIONS giving dimension sizes. For example: +// +// (dtype('float32'), (3, 5, 7)) +// +// describes a 3x5x7 array of F32s, and +// +// (dtype('O'), ((dtype('float32'), (2, 3)), +// (dtype('float64'), (4, 5)))) +// +// describes a tuple shape with two subshapes: the first a 2x3 F32, +// and the other a 4x5 F64. +// +// The Python int corresponding to a PrimitiveType enum must be valid +// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). +// +// The SWIG object wrappers generated by this file are not intended +// for end use, but rather for internal use in the Python XLA client, +// xla_client.py. +// +// One central reason for the Python-side indirection is that the +// Python-side objects produced by the typemaps in this file are +// further packaged up by xla_client before being passed on. For +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. +// +// Other SWIG object wrappers (e.g. of Computation) are further +// wrapped by xla_client in order to set up a custom destructor that +// triggers memory deallocation on the C++ side. + +%module(threads="1") xla_data + +// Keep the GIL except where explicitly specified. 
+%nothread; + +%include "tensorflow/python/platform/base.i" + +%{ +// Must be included first +#include "tensorflow/python/lib/core/numpy.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/numpy_bridge.h" + +using namespace xla; +using namespace xla::swig; + +%} + +// Basic types + + +%typemap(out) std::vector { + PyObject* out = PyList_New($1.size()); + for (int i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM(out, i, PyInt_FromLong($1[i])); + } + $result = out; +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyString_FromString($1.ConsumeValueOrDie().c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) Status { + if (!$1.ok()) { + PyErr_SetString( + PyExc_RuntimeError, $1.ToString().c_str()); + SWIG_fail; + } + Py_INCREF(Py_None); + $result = Py_None; +} + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + SWIG_fail; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + 
$1 = temps; +} + +// Literal + +%typemap(in) const Literal& (StatusOr literal_status) { + literal_status = numpy::XlaLiteralFromPyObject($input); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + SWIG_fail; + } + $1 = &literal_status.ValueOrDie(); +} + +%typemap(out) Literal (StatusOr obj_status) { + obj_status = numpy::PyObjectFromXlaLiteral(*$1); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); +} + +%typemap(out) StatusOr (StatusOr obj_status) { + if (!$1.ok()) { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } + obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + Py_DECREF(o); + SWIG_fail; + } + temps.push_back(literal_status.ConsumeValueOrDie()); + Py_DECREF(o); + } + $1 = &temps; +} + +// OpMetadata + +%typemap(in) const OpMetadata& (OpMetadata temp) { + StatusOr statusor = numpy::OpMetadataFromPyObject($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +// Shape + +%typemap(out) const Shape& { + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); +} + 
+%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyProgramShapeInfoFromXlaProgramShape( + $1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(in) const Shape& (Shape temp) { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +%typemap(in) const absl::optional& ( + absl::optional temp) { + if ($input == Py_None) { + temp = absl::nullopt; + $1 = &temp; + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; + } +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + $1 = &temps; +} + +%typemap(in) const std::vector >& ( + std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = 
PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (o == Py_None) { + temps.push_back(absl::nullopt); + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + } + $1 = &temps; +} + +// PrimitiveType + +%typemap(in) PrimitiveType { + PyObject* py_int = numpy::PyNumberToPyInt($input); + if (!py_int) { + PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); + SWIG_fail; + } + const long value = numpy::PyIntOrPyLongToLong(py_int); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + SWIG_fail; + } + if (!PrimitiveType_IsValid(value)) { + PyErr_SetString( + PyExc_TypeError, "Argument not valid for PrimitiveType enum"); + Py_DECREF(py_int); + SWIG_fail; + } + $1 = static_cast(value); +} + +// Span> + +%typemap(in) absl::Span > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!o) { + SWIG_fail; + } + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + Py_DECREF(o); + SWIG_fail; + } + PyObject* first_pyint = numpy::PyNumberToPyInt(first); + if (!first_pyint) { + PyErr_SetString( + PyExc_TypeError, + "First pair item cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + Py_DECREF(o); + Py_DECREF(first_pyint); + SWIG_fail; + } + PyObject* second_pyint = numpy::PyNumberToPyInt(second); + if (!second_pyint) { + PyErr_SetString( + PyExc_TypeError, + "Second pair item cannot be converted to int"); + Py_DECREF(o); + Py_DECREF(first_pyint); + SWIG_fail; + } + const int64 
first_value = numpy::PyIntOrPyLongToLong(first_pyint); + if (first_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + SWIG_fail; + } + const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); + if (second_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + SWIG_fail; + } + temps.push_back(std::make_pair(first_value, second_value)); + Py_DECREF(o); + } + $1 = temps; +} + +// DotDimensionNumbers + +%typemap(in) const DotDimensionNumbers& + (DotDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_contracting_dimensions", + dimension_numbers.mutable_lhs_contracting_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "rhs_contracting_dimensions", + dimension_numbers.mutable_rhs_contracting_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "lhs_batch_dimensions", + dimension_numbers.mutable_lhs_batch_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "rhs_batch_dimensions", + dimension_numbers.mutable_rhs_batch_dimensions())) { + SWIG_fail; + } + + $1 = &dimension_numbers; +} + +// PaddingConfig + +%typemap(in) const PaddingConfig& + (PaddingConfig padding_config) { + PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); + if (!dimensions) { + SWIG_fail; + } + + int length = PySequence_Size(dimensions); + if (length == -1) { + Py_DECREF(dimensions); + SWIG_fail; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(dimensions, i); + if (!item) { + Py_DECREF(dimensions); + SWIG_fail; + } + int64 edge_padding_low, edge_padding_high, interior_padding; + if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) + || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) + || !GetIntAttr(item, "interior_padding", &interior_padding)) { + Py_DECREF(item); + Py_DECREF(dimensions); + 
SWIG_fail; + } + Py_DECREF(item); + + PaddingConfig::PaddingConfigDimension* dimension = + padding_config.add_dimensions(); + dimension->set_edge_padding_low(edge_padding_low); + dimension->set_edge_padding_high(edge_padding_high); + dimension->set_interior_padding(interior_padding); + } + Py_DECREF(dimensions); + + $1 = &padding_config; +} + +// ConvolutionDimensionNumbers + +%typemap(in) const ConvolutionDimensionNumbers& + (ConvolutionDimensionNumbers dimension_numbers) { + int64 value; + + if (!GetIntAttr($input, "input_batch_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_input_batch_dimension(value); + + if (!GetIntAttr($input, "input_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_input_feature_dimension(value); + + if (!GetIntAttr($input, "output_batch_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_output_batch_dimension(value); + + if (!GetIntAttr($input, "output_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_kernel_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_kernel_input_feature_dimension(value); + + if (!HandleRepeatedInt64Attribute( + $input, "input_spatial_dimensions", + dimension_numbers.mutable_input_spatial_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "kernel_spatial_dimensions", + dimension_numbers.mutable_kernel_spatial_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "output_spatial_dimensions", + dimension_numbers.mutable_output_spatial_dimensions())) { + SWIG_fail; + } + + $1 = &dimension_numbers; +} + +// GatherDimensionNumbers + +%typemap(in) const GatherDimensionNumbers& + (GatherDimensionNumbers dimension_numbers) { + if 
(!HandleRepeatedInt64Attribute( + $input, "offset_dims", + dimension_numbers.mutable_offset_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "collapsed_slice_dims", + dimension_numbers.mutable_collapsed_slice_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "start_index_map", + dimension_numbers.mutable_start_index_map())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; + } + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// ScatterDimensionNumbers + +%typemap(in) const ScatterDimensionNumbers& + (ScatterDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "update_window_dims", + dimension_numbers.mutable_update_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "inserted_window_dims", + dimension_numbers.mutable_inserted_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "scatter_dims_to_operand_dims", + dimension_numbers.mutable_scatter_dims_to_operand_dims())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; + } + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// Span + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + ReplicaGroup rgrp; + if (!HandleRepeatedInt64Attribute( + o, "replica_ids", + rgrp.mutable_replica_ids())) { + SWIG_fail; + } + temps.push_back(rgrp); + Py_DECREF(o); + } + $1 = temps; +} diff --git a/tensorflow/compiler/xla/python/xrt.cc b/tensorflow/compiler/xla/python/xrt.cc new file mode 100644 index 
0000000000000000000000000000000000000000..2c55abc17f87c369e3d5b2140a84014e07921a9a --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.cc @@ -0,0 +1,297 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/xrt.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace swig { + 
+XrtAllocation::XrtAllocation(int64 handle, Shape shape, + const string& session_target) + : handle_(handle), shape_(shape), session_target_(session_target) {} + +XrtAllocation::~XrtAllocation() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } +} + +/* static */ +StatusOr XrtAllocation::FromLiteral( + const Literal& argument, const string& session_target) { + xrt::XLAAllocation alloc; + *alloc.mutable_value() = argument.ToProto(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto literal_string = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({literal_string, alloc.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtAllocation(handle, argument.shape(), session_target); +} + +const int64 XrtAllocation::handle() const { return handle_; } + +const Shape& XrtAllocation::shape() const { return shape_; } + +StatusOr XrtAllocation::ToLiteral() const { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto read_literal = 
tensorflow::ops::XRTReadLiteral(root, allocation_handle); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + + xla::LiteralProto response; + TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); + return Literal::CreateFromProto(response); +} + +XrtAllocationTuple::XrtAllocationTuple(std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + CHECK(element != nullptr); + } +} + +XrtAllocationTuple::~XrtAllocationTuple() { + for (XrtAllocation* element : elements_) { + if (element != nullptr) { + delete element; + } + } +} + +StatusOr XrtAllocationTuple::Release(int i) { + XrtAllocation* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int64 XrtAllocationTuple::size() const { return elements_.size(); } + +StatusOr XrtExecutable::CompileForXrt( + const string& hlo_module_proto, const std::vector& argument_shapes, + const Shape& result_shape, const string& session_target) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto compile = tensorflow::ops::XRTCompile(root, program); + TF_RETURN_IF_ERROR(root.status()); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + ProgramShape program_shape; + for (auto& shape : argument_shapes) { + *program_shape.add_parameters() = shape; + } + *program_shape.mutable_result() = result_shape; + + LayoutUtil::SetToDefaultLayout(&program_shape); + *config->mutable_program_shape() = program_shape.ToProto(); + c.mutable_hlo_snapshot() + ->mutable_hlo() + ->mutable_hlo_module() + 
->ParsePartialFromString(hlo_module_proto); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({program, c.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtExecutable(program_shape, handle, session_target); +} + +XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target) + : program_shape_(program_shape), + handle_(handle), + session_target_(session_target) {} + +XrtExecutable::~XrtExecutable() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({computation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } +} + +StatusOr XrtExecutable::Execute( + absl::Span argument_handles) { + const int num_expected_arguments = program_shape().parameters().size(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + std::vector arguments; + arguments.reserve(num_expected_arguments); + for (int i = 0; i < num_expected_arguments; ++i) { + arguments.push_back( + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); + } + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto execution_config = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto execute = tensorflow::ops::XRTExecute(root, computation_handle, + execution_config, arguments); + 
TF_RETURN_IF_ERROR(root.status()); + + TF_RET_CHECK(argument_handles.size() == arguments.size()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(false); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + for (int i = 0; i < arguments.size(); ++i) { + inputs.insert({arguments[i], argument_handles[i]->handle()}); + } + inputs.insert({computation_handle, handle()}); + inputs.insert({execution_config, e.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); + + int64 output = outputs[0].scalar()(); + return new XrtAllocation(output, program_shape().result(), session_target_); +} + +const ProgramShape& XrtExecutable::program_shape() const { + return program_shape_; +} + +int64 XrtExecutable::handle() const { return handle_; } + +void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } + +void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; } + +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target) { + const Shape& tuple_shape = allocation->shape(); + + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); + } + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); + auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + std::vector results; + for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + inputs.clear(); + 
inputs.insert({base_handle, allocation->handle()}); + inputs.insert({shape_index, {i}}); + std::vector outputs; + auto status = session.Run(inputs, {subtuple}, &outputs); + if (!status.ok()) { + // Clean up before returning non-ok status. + for (int j = 0; j < results.size(); ++j) { + delete results[j]; + } + return status; + } + const int64 subtuple_handle = outputs[0].scalar()(); + const Shape& subtuple_shape = + ShapeUtil::GetTupleElementShape(tuple_shape, i); + results.push_back( + new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); + } + return new XrtAllocationTuple(std::move(results)); +} + +} // namespace swig +} // namespace xla diff --git a/tensorflow/compiler/xla/python/xrt.h b/tensorflow/compiler/xla/python/xrt.h new file mode 100644 index 0000000000000000000000000000000000000000..dd5bba6d5c9641dadc323f70745e870c14543321 --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.h @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" + +namespace xla { +namespace swig { + +// Represents a reference to literals that live in a device-allocated buffer via +// XRT. 
Specifically, wraps an int64 handle produced by running the allocation +// graph, and an XLA shape to track the referent's shape. +class XrtAllocation { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which allocation and deallocation + // graphs are run. + static StatusOr FromLiteral(const Literal& argument, + const string& session_target); + + XrtAllocation(int64 handle, Shape shape, const string& session_target); + ~XrtAllocation(); + StatusOr ToLiteral() const; + const Shape& shape() const; + const int64 handle() const; + + private: + const int64 handle_; + const Shape shape_; + const string session_target_; +}; + +// Result of a tuple destructuring operation on an XrtAllocation. +class XrtAllocationTuple { + public: + // Note: any XrtAllocation elements that are not Release()'d will be + // deallocated in the destructor. + explicit XrtAllocationTuple(std::vector elements); + + ~XrtAllocationTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int64 size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued XrtAllocation into its constitutent elements +// in XrtAllocationTuple form. +// +// Accepts a `session_target` argument, used in constructing the +// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, +// and passed along in constructing each constituent XrtAllocation. +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target); + +// Represents a compiled computation that can be executed given handles to +// device-allocated literals. Specifically, wraps an XRT computation handle. 
+class XrtExecutable { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the compilation graph is run. + static StatusOr CompileForXrt( + const string& hlo_module_proto, const std::vector& argument_shapes, + const Shape& result_shape, const string& session_target); + + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the execution graph is run. + XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target); + ~XrtExecutable(); + + std::vector DeviceOrdinals() const { return {0}; } + + StatusOr Execute( + absl::Span argument_handles); + + const ProgramShape& program_shape() const; + int64 handle() const; + + private: + const ProgramShape program_shape_; + const int64 handle_; + const string session_target_; +}; + +// Functions for freeing resources from the Python side. +void DeleteXrtAllocation(XrtAllocation* allocation); +void DeleteXrtExecutable(XrtExecutable* computation); + +} // namespace swig +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ diff --git a/tensorflow/compiler/xla/python/xrt.i b/tensorflow/compiler/xla/python/xrt.i new file mode 100644 index 0000000000000000000000000000000000000000..456dd7be86e479b46815fc16b51a10431fe2060d --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.i @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Wrappers for XRT ops. + +%module(threads="1") xrt + +// Keep the GIL except where explicitly specified. +%nothread; + +%include "tensorflow/python/platform/base.i" +%include "tensorflow/compiler/xla/python/xla_data.i" + +%{ +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/xrt.h" + +using namespace xla; +using namespace xla::swig; + +%} + +// Computation and buffer/allocation types + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtExecutable*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocationTuple*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + XrtAllocation* xrta; + if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), + SWIG_POINTER_EXCEPTION)) == -1) { + SWIG_fail; + } + 
temps.push_back(xrta); + Py_DECREF(o); + } + $1 = temps; +} + + +%ignoreall +%unignore xla; +%unignore xla::swig; +%unignore xla::swig::XrtAllocation; +%unignore xla::swig::XrtAllocation::FromLiteral; +%unignore xla::swig::XrtAllocation::ToLiteral; +%unignore xla::swig::XrtAllocation::shape; +%unignore xla::swig::XrtAllocationTuple; +%unignore xla::swig::XrtAllocationTuple::Release; +%unignore xla::swig::XrtAllocationTuple::size; +%unignore xla::swig::XrtExecutable; +%unignore xla::swig::XrtExecutable::CompileForXrt; +%unignore xla::swig::XrtExecutable::DeviceOrdinals; +%unignore xla::swig::XrtExecutable::Execute; +%unignore xla::swig::DestructureXrtAllocationTuple; +%unignore xla::swig::DeleteXrtAllocation; +%unignore xla::swig::DeleteXrtExecutable; + +%thread; +%include "tensorflow/compiler/xla/python/xrt.h" +%nothread; + +%unignoreall diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py index 757e41a78ad2b57d2ef6e1f3055160be22c7b3ed..19bd685ab2260485d2a86f0a682d0cdd36712fdb 100644 --- a/tensorflow/compiler/xla/python_api/xla_literal.py +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -69,7 +69,7 @@ def _ConvertNumpyArrayToLiteral(ndarray): if ndarray.ndim == 0: getattr(literal, type_record.literal_field_name).append( - _np.asscalar(ndarray.astype(type_record.literal_field_type))) + ndarray.astype(type_record.literal_field_type).item()) else: # Ndarrays with boolean dtypes need special type conversion with protobufs if ndarray.dtype in {_np.bool_, _np.dtype('bool')}: diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index 95b2bf300ec67e9f034f77450416544cb088ae55..bdcd4abd6cc708795416b15412f37dde10d7fe97 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -20,6 +20,8 @@ from __future__ import print_function import numpy as _np # Avoids becoming a part of public 
Tensorflow API. +from six.moves import xrange + from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index ceb5e74db7c3b9305e9d77068df9ae0a3690af8a..08b78ee244844f41d551d7e249cec0cbf157d639 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -32,48 +32,19 @@ limitations under the License. namespace xla { -namespace { - -template -std::unique_ptr> MatmulArray2DImpl( - const Array2D& lhs, const Array2D& rhs, - const std::function& impl_fn) { - CHECK_EQ(lhs.width(), rhs.height()); - int m = lhs.height(); - int n = rhs.width(); - int k = lhs.width(); - auto result = absl::make_unique>(m, n); - // Because Eigen is a header-oriented library, make sure that the Eigen code - // is the same as the code used by the CPU backend (otherwise the linker will - // randomly pick *some* definition). 
- impl_fn( - /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, - k, - /*transpose_lhs=*/0, - /*transpose_rhs=*/0); - return result; -} - -} // namespace - /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( @@ -557,10 +528,11 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( dim2.set_base_dilation(lhs_dilation.second); *window.add_dimensions() = dim2; - const Shape& shape = ShapeInference::InferConvolveShape( - lhs_literal.shape(), rhs_literal.shape(), - /*feature_group_count=*/1, window, dnums) - .ConsumeValueOrDie(); + const Shape& shape = + ShapeInference::InferConvolveShape( + lhs_literal.shape(), rhs_literal.shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums) + .ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -572,16 +544,16 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( /*new_size=*/2, PrecisionConfig::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, precision_config)); + /*batch_group_count=*/1, window, dnums, precision_config)); HloModuleConfig config; HloModule module("ReferenceUtil", config); auto 
computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; Literal result_literal = - evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); + evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); + CHECK_EQ(result_literal.shape().rank(), 4); auto result = absl::make_unique>(result_literal.shape().dimensions(0), result_literal.shape().dimensions(1), @@ -634,24 +606,26 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); - const std::set dim_set(dims.begin(), dims.end()); + const absl::flat_hash_set dim_set(dims.begin(), dims.end()); CHECK_EQ(dim_set.size(), 3); - for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { - for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); + for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1()); + ++a0) { + for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2()); ++a1) { - for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); + for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3()); ++a2) { - for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); + for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4()); ++a3) { float accumulator = init; - for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); - ++i0) { - for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); - ++i1) { + for (int64 i0 = 0; + i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) { + for (int64 i1 = 0; + i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) { for (int64 i2 = 0; - i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { + i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) { for (int64 i3 = 0; - i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { + i3 == 0 || (dim_set.contains(3) && i3 < array.n4()); + ++i3) { // 
Handle zero-sized arrays. if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 && array.n4() > 0) { diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index d8123a6de28ca532819ece4a75cd0b725f8c1bbd..22b4218fbd5e9bc59a0de22735eb51db46670f09 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -47,6 +47,14 @@ namespace xla { }); } +::grpc::Status GRPCService::GetDeviceHandles(::grpc::ServerContext* context, + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->GetDeviceHandles(arg, result); + }); +} + ::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/, const CompileRequest* arg, CompileResponse* result) { @@ -61,6 +69,14 @@ namespace xla { [this, arg, result]() { return service_->Execute(arg, result); }); } +::grpc::Status GRPCService::ExecuteGraphParallel( + ::grpc::ServerContext* /*context*/, const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->ExecuteGraphParallel(arg, result); + }); +} + ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 3e586b288a56a22573d0c3b9ae7b2f25fdbf851a..b546704f73e34941cbf7bc2fe08062aa438039f7 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -39,6 +39,10 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; + ::grpc::Status GetDeviceHandles(::grpc::ServerContext* context, + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; + ::grpc::Status 
Compile(::grpc::ServerContext* context, const CompileRequest* arg, CompileResponse* result) override; @@ -46,6 +50,9 @@ class GRPCService : public grpc::XlaService::Service { ::grpc::Status Execute(::grpc::ServerContext* context, const ExecuteRequest* arg, ExecuteResponse* result) override; + ::grpc::Status ExecuteGraphParallel(::grpc::ServerContext* context, + const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4c21ae2a427477caa86fb4130616c38eb3bcf006..8d8394cb43ee013b9396a54e3a4d037445fcc0e1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1,6 +1,14 @@ # Description: # XLA service implementation. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -12,15 +20,6 @@ package_group( ], ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_proto_library_py", -) - xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], @@ -115,6 +114,7 @@ tf_cc_test( ":bfloat16_normalization", ":bfloat16_support", ":hlo", + ":hlo_creation_utils", ":hlo_verifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -224,23 +224,28 @@ cc_library( "hlo_evaluator_typed_visitor.h", 
"hlo_evaluator_typed_visitor_bfloat16.cc", "hlo_evaluator_typed_visitor_bool.cc", + "hlo_evaluator_typed_visitor_complex128.cc", "hlo_evaluator_typed_visitor_complex64.cc", "hlo_evaluator_typed_visitor_double.cc", "hlo_evaluator_typed_visitor_float.cc", "hlo_evaluator_typed_visitor_half.cc", + "hlo_evaluator_typed_visitor_int16.cc", "hlo_evaluator_typed_visitor_int32.cc", "hlo_evaluator_typed_visitor_int64.cc", "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_uint16.cc", "hlo_evaluator_typed_visitor_uint32.cc", "hlo_evaluator_typed_visitor_uint64.cc", "hlo_evaluator_typed_visitor_uint8.cc", ], hdrs = ["hlo_evaluator.h"], deps = [ + ":dynamic_dimension_inference", ":hlo", ":hlo_casting_utils", ":hlo_query", ":shape_inference", + "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -249,12 +254,14 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", @@ -266,6 +273,7 @@ tf_cc_test( srcs = ["hlo_evaluator_test.cc"], deps = [ ":hlo", + ":hlo_element_type_converter", ":hlo_evaluator", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", @@ -278,7 +286,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo_element_type_converter", 
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -514,6 +521,7 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -672,10 +680,10 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -694,6 +702,7 @@ cc_library( ":compiler", ":computation_layout", ":device_memory_allocator", + ":dynamic_dimension_inference", ":executable", ":execution_tracker", ":hlo", @@ -1001,6 +1010,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1012,6 +1022,7 @@ cc_library( srcs = ["name_uniquer.cc"], hdrs = ["name_uniquer.h"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", @@ -1051,7 +1062,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1089,7 +1099,6 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", - ":hlo_memory_scheduler", ":hlo_proto", ":logical_buffer", 
":tuple_points_to_analysis", @@ -1134,6 +1143,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1193,7 +1203,6 @@ cc_library( ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -1228,7 +1237,6 @@ cc_library( deps = [ ":hlo", ":hlo_proto", - "//tensorflow/compiler/xla:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -1412,6 +1420,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1451,11 +1460,15 @@ cc_library( hdrs = ["hlo_creation_utils.h"], deps = [ ":hlo", + ":hlo_module_config", ":shape_inference", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:comparators", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1495,12 +1508,25 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) +cc_library( + name = "op_expander_pass", + srcs = ["op_expander_pass.cc"], + hdrs = ["op_expander_pass.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + 
"//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "gather_expander", srcs = ["gather_expander.cc"], @@ -1509,6 +1535,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", + ":op_expander_pass", ":while_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", @@ -1532,6 +1559,28 @@ cc_library( ], ) +cc_library( + name = "triangular_solve_expander", + srcs = ["triangular_solve_expander.cc"], + hdrs = ["triangular_solve_expander.h"], + deps = [ + ":op_expander_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", @@ -1576,6 +1625,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -1590,7 +1642,7 @@ tf_cc_test( ":algebraic_simplifier", ":hlo", ":hlo_casting_utils", - ":hlo_matchers", + ":hlo_creation_utils", ":hlo_parser", ":hlo_pass", ":pattern_matcher", @@ -1695,9 +1747,9 @@ tf_cc_test( ) cc_library( - name = "convolution_feature_group_converter", - srcs = 
["convolution_feature_group_converter.cc"], - hdrs = ["convolution_feature_group_converter.h"], + name = "convolution_group_converter", + srcs = ["convolution_group_converter.cc"], + hdrs = ["convolution_group_converter.h"], deps = [ ":hlo", ":hlo_pass", @@ -1715,11 +1767,11 @@ cc_library( ) tf_cc_test( - name = "convolution_feature_group_converter_test", + name = "convolution_group_converter_test", size = "small", - srcs = ["convolution_feature_group_converter_test.cc"], + srcs = ["convolution_group_converter_test.cc"], deps = [ - ":convolution_feature_group_converter", + ":convolution_group_converter", ":hlo", ":hlo_matchers", ":hlo_parser", @@ -1782,6 +1834,7 @@ tf_cc_test( ":hlo_cse", ":hlo_dce", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", ":hlo_pass_pipeline", ":tuple_simplifier", @@ -1860,8 +1913,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -1916,6 +1970,7 @@ cc_library( hdrs = ["dynamic_dimension_inference.h"], deps = [ ":hlo", + ":while_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1925,6 +1980,46 @@ cc_library( ], ) +cc_library( + name = "dynamic_padder", + srcs = ["dynamic_padder.cc"], + hdrs = ["dynamic_padder.h"], + deps = [ + ":dynamic_dimension_inference", + ":hlo_dce", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "dynamic_padder_test", + srcs = 
["dynamic_padder_test.cc"], + deps = [ + ":dynamic_padder", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "dynamic_dimension_inference_test", srcs = ["dynamic_dimension_inference_test.cc"], @@ -2011,7 +2106,6 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -2052,6 +2146,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -2108,8 +2203,12 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2249,6 +2348,7 @@ tf_cc_test( srcs = ["hlo_dataflow_analysis_test.cc"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_dataflow_analysis", ":hlo_graph_dumper", ":hlo_matchers", 
@@ -2282,6 +2382,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -2418,6 +2519,7 @@ tf_cc_test( srcs = ["tuple_points_to_analysis_test.cc"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_matchers", ":instruction_fusion", ":tuple_points_to_analysis", @@ -2542,6 +2644,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2586,6 +2689,7 @@ tf_cc_test( srcs = ["hlo_verifier_test.cc"], deps = [ ":hlo", + ":hlo_module_config", ":hlo_parser", ":hlo_verifier", ":layout_assignment", @@ -2593,6 +2697,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -2790,7 +2895,6 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -2963,15 +3067,11 @@ cc_library( srcs = ["hlo_get_dimension_size_rewriter.cc"], hdrs = ["hlo_get_dimension_size_rewriter.h"], deps = [ + ":dynamic_dimension_inference", ":hlo", ":hlo_pass", ":shape_inference", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", ], ) @@ -3133,43 +3233,17 @@ tf_cc_test( ], ) -cc_library( - name = 
"hlo_tfgraph_builder", - srcs = ["hlo_tfgraph_builder.cc"], - hdrs = ["hlo_tfgraph_builder.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "hlo_tfgraph_builder_test", - srcs = ["hlo_tfgraph_builder_test.cc"], - deps = [ - ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "hlo_graph_dumper", srcs = [ "hlo_graph_dumper.cc", + "hlo_graph_html_renderer.cc", ], hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", ":hlo_casting_utils", ":hlo_execution_profile", - ":hlo_tfgraph_builder", ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -3179,6 +3253,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -3212,7 +3287,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -3339,7 +3413,6 @@ cc_library( ":hlo_pass_pipeline", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -3396,10 +3469,70 @@ cc_library( ":hlo_profile_printer_data", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) +cc_library( + name = "sort_simplifier", + srcs = ["sort_simplifier.cc"], + hdrs = 
["sort_simplifier.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "sort_simplifier_test", + srcs = ["sort_simplifier_test.cc"], + deps = [ + ":hlo_matchers", + ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + ":sort_simplifier", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "stable_sort_expander", + srcs = ["stable_sort_expander.cc"], + hdrs = ["stable_sort_expander.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + ":op_expander_pass", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "stable_sort_expander_test", + srcs = ["stable_sort_expander_test.cc"], + deps = [ + ":algebraic_simplifier", + ":hlo_matchers", + ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + ":stable_sort_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], @@ -3496,9 +3629,7 @@ cc_library( ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3553,7 +3684,6 @@ cc_library( ":hlo_evaluator", ":hlo_pass", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3567,14 +3697,16 @@ cc_library( tf_cc_test( name = "indexed_array_analysis_test", srcs = 
["indexed_array_analysis_test.cc"], + extra_copts = ["-Wno-string-plus-int"], deps = [ ":hlo_matchers", + ":hlo_parser", ":indexed_array_analysis", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -3596,6 +3728,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:variant", ], ) @@ -3624,7 +3757,6 @@ cc_library( srcs = ["hlo_lexer.cc"], hdrs = [ "hlo_lexer.h", - "hlo_token.h", ], deps = [ "//tensorflow/compiler/xla:shape_util", @@ -3660,6 +3792,47 @@ cc_library( ], ) +cc_library( + name = "optimize_input_output_buffer_alias", + srcs = ["optimize_input_output_buffer_alias.cc"], + hdrs = ["optimize_input_output_buffer_alias.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "optimize_input_output_buffer_alias_test", + srcs = ["optimize_input_output_buffer_alias_test.cc"], + deps = [ + ":optimize_input_output_buffer_alias", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + 
"@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "ar_crs_combiner", srcs = ["ar_crs_combiner.cc"], @@ -3669,10 +3842,10 @@ cc_library( ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "@com_google_absl//absl/container:flat_hash_map", @@ -3680,6 +3853,38 @@ cc_library( ], ) +cc_library( + name = "dynamic_index_splitter", + srcs = ["dynamic_index_splitter.cc"], + hdrs = ["dynamic_index_splitter.h"], + deps = [ + ":hlo_casting_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "dynamic_index_splitter_test", + srcs = ["dynamic_index_splitter_test.cc"], + deps = [ + ":dynamic_index_splitter", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "ar_crs_combiner_test", srcs = ["ar_crs_combiner_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 985c5af1c4d89425dd6693585e42e22510fe21f8..d566062e7401af545bd3a097d3b3735b305eba66 100644 --- 
a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -25,6 +26,9 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -32,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -41,12 +46,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -117,23 +124,37 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { transpose->dimensions()); } -// Returns true if the given reshape/copy produces a result which is bit-wise -// identical to its operand and thus may be replaced with a bitcast. 
-// -// This function is conservative -- even if this function returns false, the -// reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. -bool ReshapeOrCopyIsBitcast( - const HloInstruction* instr, - const AlgebraicSimplifierOptions::ValidBitcastCallback& - valid_bitcast_callback) { +// Recursive helper for method below. +HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper( + HloInstruction* instr, HloInstruction* operand, + const AlgebraicSimplifierOptions& options) { + // Can't replace chain of copies and reshapes with bitcasts if the compiler + // used a memory layout which isn't compatible. + if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) { + return operand; + } + + // If the operand is a copy or reshape try to see if the operand's operand + // would produce a bitcast with initial instruction. + if (HloOpcode::kReshape == operand->opcode() || + HloOpcode::kCopy == operand->opcode()) { + return BitcastingOperandOfReshapeOrCopyChainHelper( + instr, operand->mutable_operand(0), options); + } + return nullptr; +} + +// Returns an operand of a chain of reshapes and copies that is bit-wise +// identical to first reshape or copy in the chain. +HloInstruction* BitcastingOperandOfReshapeOrCopyChain( + HloInstruction* instr, const AlgebraicSimplifierOptions& options) { + if (!options.is_layout_sensitive()) { + return nullptr; + } CHECK(HloOpcode::kReshape == instr->opcode() || HloOpcode::kCopy == instr->opcode()); - - const HloInstruction* operand = instr->operand(0); - // Can't insert bitcasts if the compiler used a memory layout which isn't - // compatible. 
- return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) && - valid_bitcast_callback(operand->shape(), instr->shape()); + return BitcastingOperandOfReshapeOrCopyChainHelper( + instr, instr->mutable_operand(0), options); } bool IsUnstridedSlice(const HloInstruction* hlo) { @@ -200,6 +221,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandlePower(HloInstruction* power) override; + Status HandleRemainder(HloInstruction* remainder) override; + Status HandleReshape(HloInstruction* reshape) override; Status HandleReduce(HloInstruction* reduce) override; @@ -239,9 +262,16 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // more fusion than leaving the nodes as Dot operations. StatusOr HandleDotStrengthReduction(HloInstruction* dot); + // Removes dimension dim from hlo. + HloInstruction* StripDim(HloInstruction* hlo, int64 dim) { + CHECK_EQ(hlo->shape().dimensions(dim), 1); + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo)); + } + // Reshapes an instruction to rank 1 if it is not already rank 1. HloInstruction* Flatten(HloInstruction* hlo) { - if (ShapeUtil::Rank(hlo->shape()) == 1) { + if (hlo->shape().rank() == 1) { return hlo; } return computation_->AddInstruction(HloInstruction::CreateReshape( @@ -250,19 +280,58 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { hlo)); } - // Helper method to perform and add reduction in a single dimension. - HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + // Converts to primitive type if the input hlo is not that type, otherwise + // returns the original hlo. 
+ HloInstruction* AsType(HloInstruction* hlo, + const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + } + + // Transposes a dot operand such that the batch dimensions are the msot major, + // and the contracting dimensions are most minor. + StatusOr NormalizeDotOperandToBatchMajorAndContractingMinor( + HloInstruction* dot_operand, absl::Span batch_dimensions, + absl::Span contracting_dimensions) { + std::vector transpose_dimensions(batch_dimensions.begin(), + batch_dimensions.end()); + for (int64 i = 0; i < dot_operand->shape().rank(); ++i) { + if (!(absl::c_linear_search(batch_dimensions, i) || + absl::c_linear_search(contracting_dimensions, i))) { + transpose_dimensions.push_back(i); + } + } + transpose_dimensions.insert(transpose_dimensions.end(), + contracting_dimensions.begin(), + contracting_dimensions.end()); + return MakeTransposeHlo(dot_operand, transpose_dimensions); + } + + // Helper method to perform and add reduction on a list of dimensions. + HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); - Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + Shape shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, + hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( - shape, hlo, zero, {dim}, AddReduce_computation)); + shape, hlo, zero, dims, AddReduce_computation)); + } + + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + return AddReduce(hlo, std::vector{dim}); } - // Convenience method for replacing an instruction with a bitcast. 
- void ReplaceWithBitcast(HloInstruction* instruction); + // Convenience method for replacing an instruction with a bitcast. If operand + // is not null, then the bitcast will use the specified operand instead of the + // operand of the instruction. + void ReplaceWithBitcast(HloInstruction* instruction, + HloInstruction* operand = nullptr); // Replace old instruction with new instruction if old and new instructions // have the same shape. Updates uses and root instruction. Returns whether a @@ -391,17 +460,19 @@ bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, } } -void AlgebraicSimplifierVisitor::ReplaceWithBitcast( - HloInstruction* instruction) { +void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction, + HloInstruction* operand) { CHECK_EQ(1, instruction->operand_count()); + if (operand == nullptr) { + operand = instruction->mutable_operand(0); + } CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()), - ShapeUtil::ElementsIn(instruction->operand(0)->shape())); + ShapeUtil::ElementsIn(operand->shape())); CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), - ShapeUtil::ByteSizeOf(instruction->operand(0)->shape())); + ShapeUtil::ByteSizeOf(operand->shape())); - auto bitcast = computation_->AddInstruction( - HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, - instruction->mutable_operand(0))); + auto bitcast = computation_->AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kBitcast, operand)); TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); } @@ -562,9 +633,9 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { return Status::OK(); } - if (options_.is_layout_sensitive() && - ReshapeOrCopyIsBitcast(copy, options_.valid_bitcast_callback())) { - ReplaceWithBitcast(copy); + if (HloInstruction* bitcast_operand = + BitcastingOperandOfReshapeOrCopyChain(copy, options_)) { + ReplaceWithBitcast(copy, bitcast_operand); } return Status::OK(); @@ -677,7 
+748,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return Status::OK(); } PaddingConfig padding_config; - for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) { + for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) { auto padding_config_dim = padding_config.add_dimensions(); padding_config_dim->set_edge_padding_high(0); padding_config_dim->set_edge_padding_low(0); @@ -705,7 +776,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( static HloInstruction* BuildTupleConstant(HloComputation* computation, const LiteralSlice& literal) { - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { @@ -722,7 +793,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // Tuple constants aren't directly supported by any backend. Expand them into // explicit Tuple instructions. - if (ShapeUtil::IsTuple(constant->shape())) { + if (constant->shape().IsTuple()) { return ReplaceInstruction( constant, BuildTupleConstant(computation_, constant->literal())); } @@ -744,7 +815,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { } // If a literal is an increasing sequence from zero, replace it with an iota. 
- if (ShapeUtil::Rank(constant->shape()) == 1 && + if (constant->shape().rank() == 1 && ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsR1Iota()) { return ReplaceWithNewInstruction( @@ -781,10 +852,82 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) { return T{1.0} / constant.literal().Get(indices); }); } + +template +std::unique_ptr TryDivideToShift(HloInstruction* divide, + HloComputation* computation) { + HloInstruction *a, *b, *c; + CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); + + if (ShapeUtil::ElementIsIntegral(divide->shape()) && + !Match(b, m::ConstantEffectiveScalar(&c)) && + !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { + return nullptr; + } + + if (ShapeUtil::ElementIsSigned(divide->shape())) { + int64 b_value = c->literal().GetFirstElement(); + if (b_value > 0 && IsPowerOfTwo(static_cast(b_value))) { + // Handle negative dividends by negating the result of the division. + HloInstruction* zero_like_a = BroadcastZeros( + computation, a->shape().element_type(), a->shape().dimensions()); + + auto* dividend_is_negative = + computation->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, + zero_like_a)); + + auto* negated_dividend = computation->AddInstruction( + HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); + + auto* abs_dividend = + computation->AddInstruction(HloInstruction::CreateTernary( + a->shape(), HloOpcode::kSelect, dividend_is_negative, + negated_dividend, a)); + + int log2_abs_b_value = tensorflow::Log2Floor64(b_value); + + auto* shift_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(log2_abs_b_value))); + if (!ShapeUtil::IsScalar(b->shape())) { + shift_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); + } + + auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( + 
divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend, + shift_amount)); + + auto* neqated_quotient = + computation->AddInstruction(HloInstruction::CreateUnary( + quotient->shape(), HloOpcode::kNegate, quotient)); + + return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect, + dividend_is_negative, + neqated_quotient, quotient); + } + } else { + uint64 b_value = c->literal().GetFirstElement(); + if (IsPowerOfTwo(b_value)) { + int log2_abs_b_value = tensorflow::Log2Floor64(b_value); + HloInstruction* shift_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(log2_abs_b_value))); + if (!ShapeUtil::IsScalar(b->shape())) { + shift_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); + } + return HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount); + } + } + + return nullptr; +} } // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { - Shape* shape; HloInstruction *a, *b, *c, *d; CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); // A/1 => A @@ -793,6 +936,61 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } + // A / B => A >> log2(B) if B is a power of 2. 
+ switch (divide->shape().element_type()) { + case S8: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S16: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S32: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S64: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U8: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U16: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U32: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U64: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + default: + break; + } + + Shape* shape; // exp(A)/exp(B) => exp(A-B) if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) .WithShape(m::Shape(&shape)))) { @@ -833,6 +1031,24 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { divide->shape(), HloOpcode::kMultiply, a, new_power)); } + // A/sqrt(B) => A*rsqrt(X). 
+ if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) { + auto* rsqrt = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary(rsqrt->shape(), + HloOpcode::kMultiply, a, rsqrt)); + } + + // A/rsqrt(B) => A*sqrt(B). + if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) { + auto* sqrt = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary(sqrt->shape(), + HloOpcode::kMultiply, a, sqrt)); + } + // Simplifying integral division would produce unexpected results. if (ShapeUtil::ElementIsIntegral(divide->shape())) { return Status::OK(); @@ -843,8 +1059,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (Backends can do this transformation, but generally only if the constant is // a scalar.) if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - Literal new_literal(b->shape()); - switch (b->shape().element_type()) { + Shape result_shape = b->literal().shape(); + Literal new_literal(result_shape); + switch (result_shape.element_type()) { case F16: TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); break; @@ -860,6 +1077,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { case C64: TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); break; + case C128: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; default: return Status::OK(); } @@ -908,32 +1128,54 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - int64 lhs_collapsing_dim = - dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + + const auto kept_dim = [](int64 rank, int64 contracting_dimension, + absl::Span batch_dimensions) -> int64 { + for 
(int64 i = 0; i < rank; ++i) { + if (i != contracting_dimension && + !absl::c_linear_search(batch_dimensions, i)) { + return i; + } + } + return -1; + }; + + const int64 dot_rank = dot->shape().rank(); + const int64 rhs_rank = rhs->shape().rank(); + const int64 lhs_rank = lhs->shape().rank(); + const auto& dnums = dot->dot_dimension_numbers(); + if (dnums.rhs_contracting_dimensions_size() != 1) { + return false; + } + if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { + return false; + } + int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0); + int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim, + AsInt64Slice(dnums.lhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. + if (lhs_kept_dim == -1 && lhs_rank > 1) { + return false; + } if (lhs->IsRank2Transpose()) { lhs = lhs->mutable_operand(0); - lhs_collapsing_dim = 1 - lhs_collapsing_dim; + std::swap(lhs_collapsing_dim, lhs_kept_dim); } - const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; - int64 rhs_collapsing_dim = - dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0); + int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim, + AsInt64Slice(dnums.rhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. 
+ if (rhs_kept_dim == -1 && rhs_rank > 1) { + return false; + } if (rhs->IsRank2Transpose()) { rhs = rhs->mutable_operand(0); - rhs_collapsing_dim = 1 - rhs_collapsing_dim; + std::swap(rhs_collapsing_dim, rhs_kept_dim); } - const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; - - auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { - if (hlo->shape().element_type() == element_type) { - return hlo; - } - return computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); - }; auto reshape_if_necessary = [&](HloInstruction* hlo) { - hlo = as_type(hlo, dot->shape().element_type()); + hlo = AsType(hlo, dot->shape().element_type()); if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { hlo = computation_->AddInstruction( HloInstruction::CreateReshape(dot->shape(), hlo)); @@ -942,13 +1184,18 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( }; auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { - return AddReduce(as_type(hlo, F32), dim); + return AddReduce(AsType(hlo, F32), dim); + }; + + auto broadcast = [&](HloInstruction* hlo, const Shape& shape, + absl::Span dims) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, dims)); }; auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, int64 dim) { - return computation_->AddInstruction( - HloInstruction::CreateBroadcast(shape, hlo, {dim})); + return broadcast(hlo, shape, {dim}); }; auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { @@ -959,11 +1206,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Strength reduce dot(a[K] , b[K]) = // reshape(result.shape, // reduce_sum(multiply(a, b), {0})) - if (ShapeUtil::Rank(rhs->shape()) == 1 && - ShapeUtil::Rank(lhs->shape()) == 1) { - TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( - multiply(Flatten(lhs), Flatten(rhs)), 
0)))); + if (rhs_rank == 1 && lhs_rank == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0)))); return true; } @@ -977,8 +1222,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Simplify outer product into multiply with implicit broadcasting. // // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); @@ -992,10 +1236,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // {0}) // ) // ) - if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && - lhs->shape().dimensions(lhs_kept_dim) == 1)) { - if (ShapeUtil::Rank(rhs->shape()) == 1) { + if (lhs_rank == 1 || + (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { + if (rhs->shape().rank() == 1) { TF_RETURN_IF_ERROR( ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( multiply(Flatten(lhs), rhs), 0)))); @@ -1014,9 +1257,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // reshape(result.shape, // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) - if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_kept_dim) == 1)) { + if (rhs_rank == 1 || + (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, reshape_if_necessary(add_reduce_in_f32( multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), @@ -1024,6 +1266,97 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( lhs_collapsing_dim)))); return true; } + + // Only consider kDot with batch dimension. 
+ if (dot_rank <= 2) { + return false; + } + + CHECK_EQ(rhs_rank, lhs_rank); + CHECK_EQ(dot_rank, lhs_rank); + // If there is more than one non-contracting dimension or the batch dimensions + // are not equal, bail out since transposes may be required to do a strength + // reduction. + if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank || + !absl::c_equal(dnums.lhs_batch_dimensions(), + dnums.rhs_batch_dimensions())) { + return false; + } + + auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) { + absl::InlinedVector dims; + for (int64 i = 0; i < rank; ++i) { + if (i != non_broadcast_dim) { + dims.push_back(i); + } + } + return dims; + }; + + // If the contracting dimension is 1, remove the degnerate dimnensions from + // the lhs and rhs, broadcast each to the result shape and multiply. + if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && + (rhs_kept_dim == rhs_rank - 1 || + (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { + CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1); + const int64 lhs_kept_dim_in_output = + lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim; + absl::InlinedVector lhs_broadcast_dims; + for (const int64 dim : dnums.lhs_batch_dimensions()) { + lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? 
(dim - 1) : dim); + } + absl::InlinedVector rhs_broadcast_dims = lhs_broadcast_dims; + lhs_broadcast_dims.push_back(lhs_kept_dim_in_output); + absl::c_sort(lhs_broadcast_dims); + rhs_broadcast_dims.push_back(dot_rank - 1); + absl::c_sort(rhs_broadcast_dims); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + multiply(broadcast(StripDim(lhs, lhs_collapsing_dim), + dot->shape(), lhs_broadcast_dims), + broadcast(StripDim(rhs, rhs_collapsing_dim), + dot->shape(), rhs_broadcast_dims))))); + return true; + } + + // If the lhs and rhs non-contracting dimensions are both one, strip each one, + // multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1 && + rhs->shape().dimensions(rhs_kept_dim) == 1 && + lhs_kept_dim == rhs_kept_dim) { + auto new_lhs = StripDim(lhs, lhs_kept_dim); + auto new_rhs = StripDim(rhs, rhs_kept_dim); + const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim + ? (rhs_collapsing_dim - 1) + : rhs_collapsing_dim; + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( + multiply(new_lhs, new_rhs), reduce_dim)))); + return true; + } + + // If the lhs non-contracting dimensions is one, strip the one, brodcast to + // the rhs shape, multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1) { + auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(), + broadcast_dims(rhs_rank, rhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs), + rhs_collapsing_dim)))); + return true; + } + + // If the rhs non-contracting dimensions is one, strip the one, brodcast to + // the lhs shape, multiply and then reduce the collapsing dimension + if (rhs->shape().dimensions(rhs_kept_dim) == 1) { + auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(), + broadcast_dims(lhs_rank, lhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, 
reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs), + lhs_collapsing_dim)))); + return true; + } + return false; } @@ -1242,6 +1575,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + HloDynamicSliceInstruction* dynamic_slice = + lhs_is_dynamic_slice ? Cast(lhs) + : Cast(rhs); // ctA: HloInstruction* left_operand = @@ -1259,8 +1595,6 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. - HloInstruction* original_start_indices = - lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); // Position of start: int index_of_non_zero_start = lhs_is_dynamic_slice ? 1 - lhs_contracting_dimension @@ -1269,23 +1603,19 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( int index_of_zero_start = 1 - index_of_non_zero_start; // Slice out start and 0 components and reorder if necessary. - auto indices_type = original_start_indices->shape().element_type(); + auto indices_type = dynamic_slice->operand(1)->shape().element_type(); Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); HloInstruction* non_zero_start = - computation_->AddInstruction(HloInstruction::CreateSlice( - s_shape, original_start_indices, {index_of_non_zero_start}, - {index_of_non_zero_start + 1}, {1})); + dynamic_slice->mutable_operand(1 + index_of_non_zero_start); HloInstruction* zero_start = - computation_->AddInstruction(HloInstruction::CreateSlice( - s_shape, original_start_indices, {index_of_zero_start}, - {index_of_zero_start + 1}, {1})); - HloInstruction* new_start_indices = - lhs_is_dynamic_slice - ? 
computation_->AddInstruction(HloInstruction::CreateConcatenate( - d_shape, {non_zero_start, zero_start}, 0)) - : computation_->AddInstruction(HloInstruction::CreateConcatenate( - d_shape, {zero_start, non_zero_start}, 0)); + dynamic_slice->mutable_operand(1 + index_of_zero_start); + std::vector new_start_indices; + if (lhs_is_dynamic_slice) { + new_start_indices = {non_zero_start, zero_start}; + } else { + new_start_indices = {zero_start, non_zero_start}; + } // Build DynamicSlice(ctA x ctB). const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; @@ -1301,26 +1631,145 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - - // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are - // rank 2 or below. - if ((dot->shape().element_type() != F32 && - dot->shape().element_type() != BF16) || - ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || - ShapeUtil::Rank(dot->shape()) > 2) { + if (options_.is_layout_sensitive()) { return Status::OK(); } - // Replace a zero element dot with a broadcast of the constant 0. if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(dot->shape().element_type()))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are + // rank 2 or below. 
+ if (dot->shape().element_type() != F32 && + dot->shape().element_type() != BF16) { + return Status::OK(); + } + + // If there are no contracting dimensions, a dot can be rewritten as + // mul(broadcast(transpose(x)),broadcast(transpose(y))) + if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + if (dot->shape().rank() != lhs->shape().rank()) { + std::vector lhs_broadcast_dims(lhs->shape().rank()); + absl::c_iota(lhs_broadcast_dims, 0); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_lhs, lhs_broadcast_dims)); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + if (dot->shape().rank() != rhs->shape().rank()) { + std::vector rhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) { + rhs_broadcast_dims.push_back(i); + } + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_rhs, rhs_broadcast_dims)); + } + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, + new_lhs, new_rhs)); + } + + // If the lhs or rhs have only batch and contracting dimensions, a dot can be + // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) + if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == + lhs->shape().rank()) || + 
(dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size() == + rhs->shape().rank())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + + int64 lhs_outer_dims = + lhs->shape().rank() - + (dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + int64 rhs_outer_dims = + rhs->shape().rank() - + (dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + dot->dot_dimension_numbers().rhs_contracting_dimensions_size()); + CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0); + if (rhs_outer_dims > 0) { + std::vector lhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(lhs_broadcast_dims, 0); + lhs_broadcast_dims.resize(lhs->shape().rank()); + std::iota(lhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().lhs_batch_dimensions_size(), + lhs_broadcast_dims.end(), + dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + rhs_outer_dims); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_rhs->shape(), new_lhs, lhs_broadcast_dims)); + } else if (lhs_outer_dims > 0) { + std::vector rhs_broadcast_dims( + dot->dot_dimension_numbers().rhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + rhs_broadcast_dims.resize(rhs->shape().rank()); + std::iota(rhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size(), + rhs_broadcast_dims.end(), + 
dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + lhs_outer_dims); + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_lhs->shape(), new_rhs, rhs_broadcast_dims)); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); + std::vector reduce_dims( + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + new_dot = AsType(new_dot, F32); + const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); + absl::c_iota( + reduce_dims, + outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + new_dot = AddReduce(new_dot, reduce_dims); + new_dot = AsType(new_dot, dot->shape().element_type()); + return ReplaceInstruction(dot, new_dot); + } + + if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || + dot->shape().rank() > 2) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { + TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); + } + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, OptimizeDotOfConcat(dot)); if (dot_of_concat_optimized) { @@ -1350,7 +1799,11 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). 
- if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 && + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 && + dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 && + dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 && + lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { DotDimensionNumbers dot_dimension_numbers; dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); @@ -1549,7 +2002,7 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, case HloOpcode::kTranspose: return true; case HloOpcode::kSort: - return (!ShapeUtil::IsTuple(instruction->shape())); + return (!instruction->shape().IsTuple()); default: return false; } @@ -1595,8 +2048,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A degenerate broadcast that has the same input and output rank can be // converted into a transpose. 
- if (ShapeUtil::Rank(broadcast->shape()) == - ShapeUtil::Rank(operand->shape()) && + if (broadcast->shape().rank() == operand->shape().rank() && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " @@ -1751,7 +2203,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (HasInteriorPadding(pad->padding_config())) { PaddingConfig padding_config = pad->padding_config(); bool cleared_interior_padding = false; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { if (padding_config.dimensions(i).interior_padding() > 0 && pad->operand(0)->shape().dimensions(i) == 1) { cleared_interior_padding = true; @@ -2002,14 +2454,151 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( return changed; } +namespace { +template +std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, + HloComputation* computation) { + HloInstruction *a, *b, *c; + CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); + + if (ShapeUtil::ElementIsIntegral(remainder->shape()) && + !Match(b, m::ConstantEffectiveScalar(&c)) && + !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { + return nullptr; + } + + if (ShapeUtil::ElementIsSigned(remainder->shape())) { + int64 b_value = c->literal().GetFirstElement(); + if (b_value > 0 && IsPowerOfTwo(static_cast(b_value))) { + // Handle negative dividends by negating the result of the division. 
+ HloInstruction* zero_like_a = BroadcastZeros( + computation, a->shape().element_type(), a->shape().dimensions()); + + auto* dividend_is_negative = + computation->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, + zero_like_a)); + + auto* negated_dividend = computation->AddInstruction( + HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); + + auto* abs_dividend = + computation->AddInstruction(HloInstruction::CreateTernary( + a->shape(), HloOpcode::kSelect, dividend_is_negative, + negated_dividend, a)); + + auto* mask_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(b_value - 1))); + if (!ShapeUtil::IsScalar(b->shape())) { + mask_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); + } + + auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( + remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount)); + + auto* neqated_quotient = + computation->AddInstruction(HloInstruction::CreateUnary( + quotient->shape(), HloOpcode::kNegate, quotient)); + + return HloInstruction::CreateTernary( + remainder->shape(), HloOpcode::kSelect, dividend_is_negative, + neqated_quotient, quotient); + } + } else { + uint64 b_value = c->literal().GetFirstElement(); + if (IsPowerOfTwo(b_value)) { + HloInstruction* mask_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(b_value - 1))); + if (!ShapeUtil::IsScalar(b->shape())) { + mask_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); + } + return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd, + a, mask_amount); + } + } + return nullptr; +} +} // namespace + +Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { + HloInstruction *a, *b; + CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); + + // A 
% B => A & (B - 1) if B is a power of 2. + switch (remainder->shape().element_type()) { + case S8: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S16: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S32: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S64: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U8: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U16: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U32: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U64: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + default: + break; + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { auto operand = reshape->mutable_operand(0); // Reshape directly to empty constant if the shape contains zero-element // dimension. if (ShapeUtil::IsZeroElementArray(reshape->shape())) { + // If the instruction doesn't have a layout, use a default layout for + // the literal result. 
+ Shape reshaped_shape = reshape->shape(); + if (!LayoutUtil::HasLayout(reshaped_shape)) { + LayoutUtil::SetToDefaultLayout(&reshaped_shape); + } auto empty_constant = HloInstruction::CreateConstant( - Literal::CreateFromShape(reshape->shape())); + Literal::CreateFromShape(reshaped_shape)); return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); } @@ -2026,6 +2615,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = reshape->shape(); return ReplaceInstruction(reshape, operand); @@ -2057,12 +2647,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } // Make this a bitcast if possible. - if (options_.is_layout_sensitive() && - ReshapeOrCopyIsBitcast(reshape, options_.valid_bitcast_callback())) { - ReplaceWithBitcast(reshape); - return Status::OK(); + if (HloInstruction* bitcast_operand = + BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) { + ReplaceWithBitcast(reshape, bitcast_operand); } - return Status::OK(); } @@ -2072,8 +2660,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { auto dim_is_one = [&](int64 i) -> bool { return reverse->shape().dimensions(i) == 1; }; - if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), - dim_is_one)) { + if (absl::c_all_of(reverse->dimensions(), dim_is_one)) { return ReplaceInstruction(reverse, reverse->mutable_operand(0)); } return Status::OK(); @@ -2106,11 +2693,11 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( int64 start = slice->slice_starts(i); int64 low = padding_config.dimensions(i).edge_padding_low(); int64 data = pad->operand(0)->shape().dimensions(i); - if (start >= low && start < low + data) { - return false; + if (start < low || start >= low + data) { + return true; } } - return 
true; + return false; }(); if (in_padding) { @@ -2138,7 +2725,7 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { VLOG(10) << "Trying to simplify scalar slice of concat"; // Only do this for R1, there's no chance of this being useful otherwise. - if (ShapeUtil::Rank(slice->shape()) != 1) { + if (slice->shape().rank() != 1) { VLOG(10) << "Not folding, slice is not rank 1"; return false; } @@ -2188,7 +2775,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( return false; } HloInstruction* new_slice_operand = reshape->mutable_operand(0); - int64 slice_rank = ShapeUtil::Rank(slice->shape()); + int64 slice_rank = slice->shape().rank(); std::vector sliced_dims; for (int64 i = 0; i < slice_rank; ++i) { if (slice->slice_starts(i) != 0 || @@ -2200,7 +2787,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && slice->slice_starts(0) == 0) { const Shape& new_slice_shape = new_slice_operand->shape(); - const int64 rank = ShapeUtil::Rank(new_slice_shape); + const int64 rank = new_slice_shape.rank(); std::vector new_slice_starts(rank, 0); std::vector new_slice_stides(rank, 1); std::vector new_slice_limits(new_slice_shape.dimensions().begin(), @@ -2297,28 +2884,71 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Most of those optimizations can be done for multi-output - // reduces. 
- if (ShapeUtil::IsTuple(reduce->shape())) { - return Status::OK(); - } +Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { + HloReduceInstruction* reduce = Cast(hlo); + bool multi_output_reduce = reduce->shape().IsTuple(); + + // For tuple reduce, we require all reduce shapes to be the same, up to the + // element types, so we can just the first operand and the first result as a + // representative. + auto arg = reduce->inputs()[0]; + auto init_value = reduce->init_values()[0]; + const Shape& reduce_result_shape = + multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape(); - auto arg = reduce->mutable_operand(0); - auto init_value = reduce->mutable_operand(1); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (ShapeUtil::IsZeroElementArray(arg->shape()) || - ShapeUtil::IsZeroElementArray(reduce->shape())) { - return ReplaceWithNewInstruction( - reduce, - HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); + ShapeUtil::IsZeroElementArray(reduce_result_shape)) { + if (multi_output_reduce) { + std::vector broadcast_inits; + int64 inputs = reduce->input_count(); + for (int64 i = 0; i < inputs; ++i) { + broadcast_inits.push_back(computation_->AddInstruction( + HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i), + reduce->init_values()[i], {}))); + } + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateTuple(broadcast_inits)); + } else { + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {})); + } + } + + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely. 
+ if (ShapeUtil::ElementsIn(reduce_result_shape) == + ShapeUtil::ElementsIn(arg->shape())) { + if (multi_output_reduce) { + std::vector reshaped_args; + int64 inputs = reduce->input_count(); + for (int64 i = 0; i < inputs; ++i) { + reshaped_args.push_back( + computation_->AddInstruction(HloInstruction::CreateReshape( + reduce->shape().tuple_shapes(i), reduce->inputs()[i]))); + } + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateTuple(reshaped_args)); + } else { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce_result_shape, arg)); + } + } + + // TODO(b/112040122): Most of those optimizations below can be done for + // multi-output reduces. + if (multi_output_reduce) { + return Status::OK(); } // A Transpose feeding a reduce can simply permute the reduction dimensions // field if the output of the reduce is a vector or scalar. Higher ranked // result may require a transpose of the output. - if (ShapeUtil::Rank(reduce->shape()) <= 1 && + if (reduce_result_shape.rank() <= 1 && arg->opcode() == HloOpcode::kTranspose) { auto transpose_dimensions = arg->dimensions(); std::vector new_reduce_dimensions; @@ -2327,20 +2957,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( - reduce->shape(), arg->mutable_operand(0), init_value, + reduce_result_shape, arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); } - // If the reduction results in the same number of elements, then the only - // possible side effect would be a reshape. Since the init_value is an - // identity of the reduction function, we can therefore replace the reduce - // with a simple reshape, ignoring the reduction function completely. 
- if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape())) { - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); - } - // If a reduce feeds a reduce with the same computation and initial value, // they can be combined into a single reduce. if (arg->opcode() == HloOpcode::kReduce && @@ -2349,9 +2969,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // Create a new reduce with the combined reduction dimensions of both // reduces. std::vector arg_dims = arg->dimensions(); - std::sort(arg_dims.begin(), arg_dims.end()); + absl::c_sort(arg_dims); std::vector reduce_dims = reduce->dimensions(); - std::sort(reduce_dims.begin(), reduce_dims.end()); + absl::c_sort(reduce_dims); // Transform reduce_dims to the same rank as the operand of the operand. for (int64 arg_dim : arg_dims) { for (int64& dim : reduce_dims) { @@ -2366,9 +2986,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), reduce_dims.end(), std::back_inserter(new_dimensions)); return ReplaceWithNewInstruction( - reduce, - HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), - init_value, new_dimensions, function)); + reduce, HloInstruction::CreateReduce( + reduce_result_shape, arg->mutable_operand(0), init_value, + new_dimensions, function)); } // A reshape that collapses multiple dimensions into a dimension being @@ -2378,8 +2998,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { std::vector> unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), arg->shape()); - std::vector arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true); - std::vector arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false); + std::vector arg_dim_in_output(arg->shape().rank(), true); + std::vector arg_dim_unmodified(arg->shape().rank(), false); for (auto dim : 
dimensions) { arg_dim_in_output[dim] = false; } @@ -2397,21 +3017,21 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } if (can_move_reshape_into_reduce) { changed_ = true; - std::unordered_set dimensions_not_to_reduce; + absl::flat_hash_set dimensions_not_to_reduce; for (auto dim_pair : unmodified_dims) { if (arg_dim_in_output[dim_pair.second]) { dimensions_not_to_reduce.insert(dim_pair.first); } } std::vector new_reduce_dimensions; - for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) { - if (dimensions_not_to_reduce.count(i) == 0) { + for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) { + if (!dimensions_not_to_reduce.contains(i)) { new_reduce_dimensions.push_back(i); } } return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( - reduce->shape(), arg->mutable_operand(0), init_value, + reduce_result_shape, arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); } } @@ -2426,11 +3046,11 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { HloInstruction* old_reduce = nullptr; for (HloInstruction* operand : arg->operands()) { HloInstruction* new_reduce = computation_->AddInstruction( - HloInstruction::CreateReduce(reduce->shape(), operand, init_value, + HloInstruction::CreateReduce(reduce_result_shape, operand, init_value, reduce->dimensions(), function)); if (old_reduce != nullptr) { new_reduce = computation_->AddInstruction(HloInstruction::CreateMap( - reduce->shape(), {old_reduce, new_reduce}, function)); + reduce_result_shape, {old_reduce, new_reduce}, function)); } old_reduce = new_reduce; } @@ -2459,6 +3079,55 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( function)); } + if (options_.enable_window_reduce_to_reduce_replacement()) { + // A reduce window can be expressed as a reduce and a reshape if all + // dimensions either have a window size of one or the entire dimension. 
If + // there is no stride, dilation, or padding, this is as easy as checking the + // size of the output shape and window dimension. + // + // The reshape is a bitcast since it adds one-sized dimensions. Often these + // ones are immediately removed as well with another reshape. The + // implementation of reduce tends to be slightly more efficient at reducing + // entire dimensions compared to reduce window. + auto effective_reduce_dims = [&] { + if (window_util::HasStride(window) || window_util::HasDilation(window) || + window_util::HasPadding(window)) { + return absl::InlinedVector{}; + } + absl::InlinedVector reduce_dims; + for (int64 i = 0; i < window.dimensions_size(); ++i) { + if (window.dimensions(i).size() == 1) { + continue; + } else if (reduce_window->shape().dimensions(i) == 1) { + reduce_dims.push_back(i); + } else { + return absl::InlinedVector{}; + } + } + return reduce_dims; + }(); + + // If a reduce window can be expressed as a reduce, do so and reshape the + // output. + if (!effective_reduce_dims.empty()) { + Shape reduce_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(effective_reduce_dims, dim); + }, + reduce_window->shape()); + HloInstruction* reduce = + computation_->AddInstruction(HloInstruction::CreateReduce( + /*shape=*/reduce_shape, + /*operand=*/operand, + /*init_value=*/reduce_window->mutable_operand(1), + /*dimensions_to_reduce=*/effective_reduce_dims, + /*reduce_computation=*/function)); + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateReshape(reduce_window->shape(), reduce)); + } + } + // This optimization folds a pad op into reduce_window. HloInstruction* pad; const HloInstruction* convert = nullptr; @@ -2594,7 +3263,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // Carry out the folding of the pad into reduce_window. 
VLOG(10) << "Folding pad into reduce-window."; Window new_window = window; - const int64 rank = ShapeUtil::Rank(reduce_window->shape()); + const int64 rank = reduce_window->shape().rank(); TF_RET_CHECK(pad_config.dimensions_size() == rank); TF_RET_CHECK(window.dimensions_size() == rank); for (int64 i = 0; i < rank; ++i) { @@ -2643,110 +3312,24 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return ReplaceWithNewInstruction( sort, HloInstruction::CreateTuple(sort->operands())); } - if (!options_.enable_permutation_sort_replacement()) { - return Status::OK(); - } - // Check if we are sorting a permutation. In that case, we know that the keys - // will be sorted to the identity permutation, and we can represent the - // changes to the 'values' parameter as a scatter. - if (sort->operand_count() == 2 && - operand->opcode() == HloOpcode::kGetTupleElement) { - const HloInstruction* other_sort = operand->operand(0); - // Check whether the 'values' parameter is the result of another sort with - // the same sort dimension. - if (other_sort->opcode() == HloOpcode::kSort && - other_sort->operand_count() >= 2 && - other_sort->dimensions(0) == dimension_to_sort && - other_sort->operand(operand->tuple_index())->opcode() == - HloOpcode::kIota) { - auto* iota = - Cast(other_sort->operand(operand->tuple_index())); - // The sort operand needs to be an integral iota, and the iota dimension - // needs to be the dimension that was sorted. - if (iota->iota_dimension() == dimension_to_sort && - ShapeUtil::ElementIsIntegral(iota->shape())) { - // We use the following construction method for a Scatter that applies - // the permutation from 'keys' to the 'values' parameter. - // - Take the "keys" parameter of the second sort and reshape it to have - // another "1" dimension at the end. 
- // - Concatenate it with iotas of the same extended shape with all - // different iota_dimensions except the dimension_to_sort in the order - // of iota_dimensions/dimension_to_sort, so e.g. with rank 3 and - // dimension_to_sort = 1, we would have concatenate of (iota with - // iota_dimension=0, keys, iota with iota_dimension = 2) - // - Use this as the indices parameter of scatter, and set updates - // of the scatter to be a reshaped 'values' parameter of sort (adding - // 'rank' many 1 dimensions at the end). - int64 rank = ShapeUtil::Rank(operand->shape()); - Shape extended_shape = operand->shape(); - extended_shape.add_dimensions(1); - extended_shape.mutable_layout()->add_minor_to_major(rank); - auto reshaped_permutation = computation_->AddInstruction( - HloInstruction::CreateReshape(extended_shape, operand)); - std::vector concat_operands; - for (int64 i = 0; i < rank; ++i) { - if (i == dimension_to_sort) { - concat_operands.push_back(reshaped_permutation); - } else { - concat_operands.push_back(computation_->AddInstruction( - HloInstruction::CreateIota(extended_shape, i))); - } - } - Shape concat_shape = operand->shape(); - concat_shape.add_dimensions(rank); - concat_shape.mutable_layout()->add_minor_to_major(rank); - auto scatter_indices = - rank > 1 ? computation_->AddInstruction( - HloInstruction::CreateConcatenate( - concat_shape, concat_operands, rank)) - : reshaped_permutation; - - // We don't care about the operand, it will be completely overridden by - // the updates. - auto scatter_operand = computation_->AddInstruction( - HloInstruction::CreateIota(sort->operand(1)->shape(), 0)); - - // Construct the updates operand of scatter. 
- Shape update_shape = sort->operand(1)->shape(); - for (int64 i = 0; i < rank; ++i) { - update_shape.add_dimensions(1); - update_shape.mutable_layout()->add_minor_to_major(rank + i); - } - auto scatter_updates = - computation_->AddInstruction(HloInstruction::CreateReshape( - update_shape, sort->mutable_operand(1))); - - // Construct the updates computation, which simply replaces the operand - // values with the update values. - HloComputation::Builder b("update_replace_computation"); - Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); - b.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "scalar_rhs")); - auto update_replace_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_rhs)); - - ScatterDimensionNumbers dim_numbers; - dim_numbers.set_index_vector_dim(rank); - for (int64 i = 0; i < rank; ++i) { - dim_numbers.add_update_window_dims(rank + i); - dim_numbers.add_scatter_dims_to_operand_dims(i); - } - auto scatter = - computation_->AddInstruction(HloInstruction::CreateScatter( - sort->operand(1)->shape(), scatter_operand, scatter_indices, - scatter_updates, update_replace_computation, dim_numbers)); - return ReplaceWithNewInstruction( - sort, HloInstruction::CreateTuple( - {computation_->AddInstruction(HloInstruction::CreateIota( - operand->shape(), dimension_to_sort)), - scatter})); - } + return Status::OK(); +} + +namespace { +bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape, + absl::Span perm) { + std::vector new_permutation; + int64 degenerate_count = 0; + for (int64 i = 0; i < perm.size(); ++i) { + if (shape.dimensions(i) != 1) { + new_permutation.push_back(perm[i]); + } else { + ++degenerate_count; } } - return Status::OK(); + return degenerate_count > 1 && absl::c_is_sorted(new_permutation); } +} // namespace Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto 
operand = transpose->mutable_operand(0); @@ -2764,6 +3347,15 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + // Replace transpose with a reshape if more than one degenerate method is + // permuted. + if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(), + transpose->dimensions())) { + return ReplaceWithNewInstruction( + transpose, HloInstruction::CreateReshape( + transpose->shape(), transpose->mutable_operand(0))); + } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = transpose->shape(); return ReplaceInstruction(transpose, operand); @@ -3011,15 +3603,6 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( convolution_shape.element_type(), {conv_width, output_channels}); - // We cannot insert bitcasts if the layouts will not be compatible. - // TODO(b/33178038): Consider inserting a transpose if a bitcast would be - // invalid. 
- if (!options_.valid_bitcast_callback()(input_shape, new_input_shape) || - !options_.valid_bitcast_callback()(filter_shape, new_filter_shape) || - !options_.valid_bitcast_callback()(dot_output_shape, convolution_shape)) { - return false; - } - auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); DotDimensionNumbers dot_dimension_numbers; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index d2775b9fafa7e4c625f5d181114e80e7369f9c78..df5a8c2ec141458a95fafb76b1e99e4b04a61b28 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -25,21 +25,25 @@ namespace xla { class AlgebraicSimplifierOptions { public: - // Given shapes 'from_shape' and 'to_shape', determines if it is valid to - // bitcast from 'from_shape' to 'to_shape' after considering platform - // dependent effects on layout like alignment restrictions. Precondition: the - // two shapes have layouts, the same number of elements and - // ShapeUtil::ReshapeIsBitcast returns true. - using ValidBitcastCallback = + AlgebraicSimplifierOptions() {} + // Platform dependent callback to determine if a reshape `from_shape` to + // `to_shape` is a bitcast. + using ReshapeIsBitcastCallback = std::function; - explicit AlgebraicSimplifierOptions( - ValidBitcastCallback valid_bitcast_callback) - : valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} - // If valid_bitcast_callback returns true, then the pass will replace reshapes - // and transposes with bitcasts. - const ValidBitcastCallback& valid_bitcast_callback() const { - return valid_bitcast_callback_; + ReshapeIsBitcastCallback reshape_is_bitcast_callback) + : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)) {} + + // Use the platform specific callback if set. It is not sensible to return + // true here if the options are not layout sensitive. 
+ bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const { + if (!is_layout_sensitive_) { + return false; + } + if (!reshape_is_bitcast_callback_) { + return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape); + } + return reshape_is_bitcast_callback_(from_shape, to_shape); } // If is_layout_sensitive is true, then the simplifier preserves layout during @@ -47,12 +51,14 @@ class AlgebraicSimplifierOptions { void set_is_layout_sensitive(bool is_layout_sensitive) { is_layout_sensitive_ = is_layout_sensitive; } + bool is_layout_sensitive() const { return is_layout_sensitive_; } // Enable dot simplification on platforms where it is profitable. void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { enable_dot_strength_reduction_ = enable_dot_strength_reduction; } + bool enable_dot_strength_reduction() const { return enable_dot_strength_reduction_; } @@ -65,22 +71,24 @@ class AlgebraicSimplifierOptions { return enable_conv_simplification_; } - // If enable_permutation_sort_replacement is true, a sort op that is known to - // sort a permutation will be replaced with a scatter op. - void set_enable_permutation_sort_replacement( - bool enable_permutation_sort_replacement) { - enable_permutation_sort_replacement_ = enable_permutation_sort_replacement; + // If enable_window_reduce_replacement is true, the kReduceWindow instruction + // can be optimized by replacement with simpler operations. 
+ void set_enable_window_reduce_to_reduce_replacement( + bool enable_window_reduce_to_reduce_replacement) { + enable_window_reduce_to_reduce_replacement_ = + enable_window_reduce_to_reduce_replacement; } - bool enable_permutation_sort_replacement() const { - return enable_permutation_sort_replacement_; + + bool enable_window_reduce_to_reduce_replacement() const { + return enable_window_reduce_to_reduce_replacement_; } private: - ValidBitcastCallback valid_bitcast_callback_; + ReshapeIsBitcastCallback reshape_is_bitcast_callback_; bool is_layout_sensitive_{false}; bool enable_dot_strength_reduction_{true}; bool enable_conv_simplification_{true}; - bool enable_permutation_sort_replacement_{false}; + bool enable_window_reduce_to_reduce_replacement_{true}; }; // A pass which performs algebraic simplifications. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 14ce519b6a0fd221070006d336d23bddeb6cd621..06f6206a3b3d0007dc4b6a91395babb510bf023e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -46,17 +47,9 @@ namespace { using ::testing::ElementsAre; namespace m = match; -AlgebraicSimplifierOptions::ValidBitcastCallback bitcasting_callback() { - return [](const Shape&, const Shape&) { return true; }; -} - -AlgebraicSimplifierOptions::ValidBitcastCallback non_bitcasting_callback() { - return [](const Shape&, const Shape&) { return false; }; -} - class AlgebraicSimplifierTest : public HloTestBase { protected: - AlgebraicSimplifierOptions default_options_{non_bitcasting_callback()}; + AlgebraicSimplifierOptions default_options_; }; // Test that A + 0 is simplified to A @@ -202,6 +195,86 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { m::Broadcast(m::ConstantScalar(0.125))))); } +TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = u32[4] parameter(0) + c = u32[] constant(8) + b = u32[4] broadcast(c), dimensions={} + ROOT d = u32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::ShiftRightLogical( + m::Parameter(0), m::Broadcast(m::ConstantScalar(3))))); +} + +TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[4] parameter(0) + c = s32[] constant(8) + b = s32[4] broadcast(c), dimensions={} + ROOT d = s32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, 
ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto match_dividend_is_negative = + m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0))); + auto match_abs = m::Select(match_dividend_is_negative, + m::Negate(m::Parameter(0)), m::Parameter(0)); + auto match_shift = + m::ShiftRightLogical(match_abs, m::Broadcast(m::ConstantScalar(3))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Select(match_dividend_is_negative, + m::Negate(match_shift), match_shift))); +} + +TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = u32[4] parameter(0) + c = u32[] constant(8) + b = u32[4] broadcast(c), dimensions={} + ROOT r = u32[4] remainder(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::AndAnyOrder(m::Parameter(0), + m::Broadcast(m::ConstantScalar(7))))); +} + +TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[4] parameter(0) + c = s32[] constant(8) + b = s32[4] broadcast(c), dimensions={} + ROOT r = s32[4] remainder(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto match_dividend_is_negative = + m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0))); + auto match_abs = m::Select(match_dividend_is_negative, + m::Negate(m::Parameter(0)), m::Parameter(0)); + auto match_and = + m::AndAnyOrder(match_abs, m::Broadcast(m::ConstantScalar(7))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Select(match_dividend_is_negative, + m::Negate(match_and), match_and))); +} + // Test that A * 0 
is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { auto m = CreateNewVerifiedModule(); @@ -1273,7 +1346,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { // Create add computation. builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m->AddEntryComputation(builder.Build()); HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), @@ -1283,6 +1356,51 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { GmockMatch(m::Broadcast(m::Constant()))); } +TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) { + auto m = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "param")); + Window window; + for (int64 i = 0; i < 4; ++i) { + WindowDimension* dim = window.add_dimensions(); + // Makes 1x2x3x1 window. + dim->set_size((i % 3) + 1); + dim->set_stride(1); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + // Create add computation. 
+ HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = m->AddEmbeddedComputation(builder.Build()); + } + builder.AddInstruction(HloInstruction::CreateReduceWindow( + ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), + window, add_computation)); + m->AddEntryComputation(builder.Build()); + HloPassFix simplifier(default_options_); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant()))); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant())))); +} + TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -1419,23 +1537,77 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { EXPECT_THAT(computation->root_instruction(), param0); } -TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { +TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) { auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), "param")); - *param->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout({0, 1, 2, 3}); + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}), + "param")); HloInstruction* copy = 
builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); - *copy->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout({1, 2, 0, 3}); + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + HloOpcode::kCopy, param)); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1}), copy)); + builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), + HloOpcode::kCopy, reshape)); + auto computation = m->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0)))))); + + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + // Verify that the copy of reshape of copy is replaced. 
+ EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}), + "param")); + HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + HloOpcode::kCopy, param)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), copy)); + + auto computation = m->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Copy(m::Parameter(0))))); + + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + // Verify that the copy of reshape of copy is replaced. 
+ EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + "param")); + builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3}), + HloOpcode::kCopy, param)); auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier1(options); ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); @@ -1443,10 +1615,10 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options2(bitcasting_callback()); + AlgebraicSimplifierOptions options2; options2.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier2(options2); - ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); + EXPECT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. 
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Bitcast(m::Parameter(0)))); @@ -1699,7 +1871,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); @@ -1729,7 +1901,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1759,7 +1931,8 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); @@ -1790,8 +1963,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that reshape(transpose(rng)) is replace by a single rng of the @@ -1842,7 +2014,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { m::Op().Is(dimensions_wrong_reshape), m::Op().Is(layout_wrong_reshape)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + 
AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); simplifier.Run(m.get()).ValueOrDie(); @@ -1872,8 +2044,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1897,8 +2068,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0, 1})); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1923,7 +2093,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1953,7 +2123,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -2010,7 +2180,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { EXPECT_THAT(computation->root_instruction(), 
GmockMatch(m::Copy(m::Copy(m::Parameter(0))))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -2047,6 +2217,26 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { computation->root_instruction()->dimensions()); } +TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[10] parameter(0) + reshaped = f32[1,1,10] reshape(f32[10] param) + transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0} + ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter())); +} + // Test merging reshape and broadcast. 
TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto m = CreateNewVerifiedModule(); @@ -2558,93 +2748,23 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); - auto module = CreateNewVerifiedModule(); + TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &builder, + module.get()) + .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), keys); } -TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { - const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={1} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options(non_bitcasting_callback()); - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, - GmockMatch(m::Tuple( - m::Iota(), - m::Scatter(m::Iota(), m::Concatenate(m::Iota(), m::Reshape()), - m::Reshape())))); -} - -TEST_F(AlgebraicSimplifierTest, 
DontReplacePermutationSortIfNonIntegral) { - // Same as ReplacePermutationSortWithScatter except that the iota has F32 - // type. - const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = f32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = f32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(gte, values), dimensions={1} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options(non_bitcasting_callback()); - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); -} - -TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { - // Same as ReplacePermutationSortWithScatter except that the sort dimensions - // don't match. 
- const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={0} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options(non_bitcasting_callback()); - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); -} - TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0}); Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); @@ -2654,10 +2774,11 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { HloInstruction::CreateParameter(1, values_shape, "values0")); auto values1 = builder.AddInstruction( HloInstruction::CreateParameter(2, values_shape, "values1")); - builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, - keys, {values0, values1})); - auto module = CreateNewVerifiedModule(); + TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape( + {keys_shape, values_shape, values_shape}), + {keys, values0, values1}, 0, /*is_stable=*/false, + &builder, module.get()) + .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2879,7 +3000,7 @@ class ConvInputPaddingTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; 
-INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( ConvInputPaddingTestCases, ConvInputPaddingTest, ::testing::ValuesIn(std::vector{ // Merge this edge padding into the conv. @@ -2950,11 +3071,11 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(); builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), - /*feature_group_count=*/1, window, - dnums) + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums) .ValueOrDie(), - lhs_pad, filter, /*feature_group_count=*/1, window, dnums, - DefaultPrecisionConfig(2))); + lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -2987,7 +3108,7 @@ class ConvFilterPaddingTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( ConvFilterPaddingTestCases, ConvFilterPaddingTest, ::testing::ValuesIn(std::vector{ // Can only merge interior padding on the filter's spatial dimensions; @@ -3067,11 +3188,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - /*feature_group_count=*/1, window, - dnums) + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums) .ValueOrDie(), - input, rhs_pad, /*feature_group_count=*/1, window, dnums, - precision_config)); + input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums, precision_config)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -3219,13 +3340,14 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve( out_shape, input, filter, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, 
/*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); - AlgebraicSimplifierOptions simplifier_options(bitcasting_callback()); + AlgebraicSimplifierOptions simplifier_options; simplifier_options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(simplifier_options); if (!simplifier.Run(module.get()).ValueOrDie()) { @@ -3431,7 +3553,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Create the reduce-window. Window window; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { auto* dim = window.add_dimensions(); dim->set_size(1); dim->set_padding_low(10); @@ -3517,7 +3639,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Create the reduce-window. Window window; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { auto* dim = window.add_dimensions(); dim->set_size(1); dim->set_padding_low(10); @@ -3592,8 +3714,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); @@ -3639,12 +3761,16 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + std::vector params; + for (int i = 0; i < 3; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i + 1, 
ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + } builder.AddInstruction(HloInstruction::CreateDynamicSlice( shape, builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + params, /*slice_sizes=*/{10, 100, 1000})); auto computation = m->AddEntryComputation(builder.Build()); @@ -3663,28 +3789,35 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + std::vector slice_indices, update_indices; + for (int i = 0; i < 3; ++i) { + slice_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + update_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + i + 5, ShapeUtil::MakeShape(U32, {}), "update_indices"))); + } HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( slice_shape, builder.AddInstruction( HloInstruction::CreateParameter(0, full_shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + slice_indices, /*slice_sizes=*/{10, 1, 1000})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( slice_shape, builder.AddInstruction( - HloInstruction::CreateParameter(2, slice_shape, "to_update")), - slice, - builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); + HloInstruction::CreateParameter(4, slice_shape, "to_update")), + slice, update_indices)); auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - GmockMatch(m::DynamicSlice(m::Parameter(), 
m::Parameter()))); + GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(), + m::Parameter(), m::Parameter()))); } // Test that two consecutive broadcasts can be merged to one. @@ -3791,7 +3924,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3812,7 +3945,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3827,17 +3960,38 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { param = f32[3,4] parameter(0) constant = f32[] constant(0.0) pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 - ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]} } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + } + )"; + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { const char* hlo_string = R"( HloModule module @@ -3852,13 +4006,36 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Parameter())); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) { + const char* hlo_string = R"( + HloModule module + + ENTRY entry () -> f32[1]{0} { + constant.val = f32[] constant(4) + constant.pad = f32[] constant(-7) + reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val) + pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0 + slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]} + ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0)))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { const char* hlo_string = R"( HloModule module @@ -3874,7 +4051,7 @@ TEST_F(AlgebraicSimplifierTest, 
SliceOfConcatScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3896,7 +4073,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3918,7 +4095,7 @@ TEST_F(AlgebraicSimplifierTest, NegateNegate) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3938,7 +4115,7 @@ TEST_F(AlgebraicSimplifierTest, NotNot) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -4065,9 +4242,6 @@ PadReduceWindowEffectiveBroadcastCases() { {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6}, /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, /*should_become_broadcast=*/false}, // - {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, - /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true, - /*should_become_broadcast=*/true}, // {/*input_spatials=*/{1, 
1}, /*symmetric_pad_amount=*/{2, 2}, /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true, /*should_become_broadcast=*/false}, // @@ -4078,11 +4252,80 @@ PadReduceWindowEffectiveBroadcastCases() { return *cases; } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( PadReduceWindowEffectiveBroadcastInstantiation, PadReduceWindowEffectiveBroadcastTest, ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); +class BatchDotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { + auto module = CreateNewVerifiedModule(); + int m, k, n; + PrimitiveType element_type; + std::tie(m, k, n, element_type) = GetParam(); + std::vector lhs_dims = {1, 3, 5}; + std::vector rhs_dims = lhs_dims; + std::vector output_dims = lhs_dims; + if (m > 0) { + lhs_dims.push_back(m); + output_dims.push_back(m); + } + if (k > 0) { + lhs_dims.push_back(k); + rhs_dims.push_back(k); + } + if (n > 0) { + rhs_dims.push_back(n); + output_dims.push_back(n); + } + Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs_shape, "lhs")); + auto rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs_shape, "rhs")); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + dot_dnums.add_lhs_batch_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(2); + if (k > 0) { + dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 
4 : 3); + dot_dnums.add_rhs_contracting_dimensions(3); + } + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = + m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1; + EXPECT_EQ(changed, dot_should_be_transformed); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation, + BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(F32, BF16))); + class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< @@ -4135,7 +4378,7 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), @@ -4297,9 +4540,10 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { HloInstruction* const update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); HloInstruction* const start_indices = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0({}))); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - dslice_shape, operand, update, start_indices)); + dslice_shape, operand, update, + 
std::initializer_list({start_indices}))); const HloComputation* const computation = m->AddEntryComputation(builder.Build()); @@ -4308,9 +4552,9 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { EXPECT_THAT(computation->root_instruction(), operand); } -INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, - DotOfConcatSimplificationTest, - ::testing::ValuesIn(kDotOfConcatTestSpecs)); +INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation, + DotOfConcatSimplificationTest, + ::testing::ValuesIn(kDotOfConcatTestSpecs)); struct DotOfGatherTestSpec { int64 m; @@ -4352,14 +4596,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int32 start_row = (spec.lcd == 0) ? 0 : spec.s; int32 start_col = (spec.lcd == 0) ? spec.s : 0; - const auto start_indices = + std::vector start_indices = { builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR0(start_row))), + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_col)))}; int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; - Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + std::vector slice_sizes = {slice_row_size, slice_col_size}; + Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes); auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + ds_shape, lhs, start_indices, slice_sizes)); int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; int64 rhs_cols = (spec.rcd == 0) ? 
spec.n : spec.k; @@ -4392,7 +4639,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { } else { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), - m::Concatenate()))); + m::Constant(), m::Constant()))); } } @@ -4430,14 +4677,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int32 start_row = (spec.rcd == 0) ? 0 : spec.s; int32 start_col = (spec.rcd == 0) ? spec.s : 0; - const auto start_indices = + std::vector start_indices = { + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_row))), builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR0(start_col)))}; int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; - Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + std::vector slice_sizes = {slice_row_size, slice_col_size}; + Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes); auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + ds_shape, rhs, start_indices, slice_sizes)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(spec.lcd); @@ -4462,7 +4712,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { } else { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), - m::Concatenate()))); + m::Constant(), m::Constant()))); } } @@ -4510,9 +4760,160 @@ std::vector DotOfGatherPositiveNegativeTests() { return all; } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); +TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) { + const char* hlo_string = R"( +HloModule module + +reducer { + parameter.1 = f32[] 
parameter(0) + parameter.3 = f32[] parameter(2) + add.2 = f32[] add(parameter.1, parameter.3) + parameter.0 = f32[] parameter(1) + parameter.2 = f32[] parameter(3) + add.3 = f32[] add(parameter.0, parameter.2) + ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3) +} + +ENTRY entry { + parameter.6 = (f32[], f32[]) parameter(0) + get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0 + get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1 + constant = f32[] constant(0) + ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Tuple( + m::Reshape(m::GetTupleElement(m::Parameter(), 0)), + m::Reshape(m::GetTupleElement(m::Parameter(), 1))))); +} + +TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) { + const char* hlo_string = R"( +HloModule module + +reducer { + parameter.1 = f32[] parameter(0) + parameter.3 = f32[] parameter(2) + mul.2 = f32[] add(parameter.1, parameter.3) + parameter.0 = f32[] parameter(1) + parameter.2 = f32[] parameter(3) + add.3 = f32[] add(parameter.0, parameter.2) + ROOT tuple.4 = (f32[], f32[]) tuple(mul.2, add.3) +} + +ENTRY entry { + parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0) + get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0 + get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1 + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + 
ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::ConstantScalar(0)), + m::Broadcast(m::ConstantScalar(1))))); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1}), "param")); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {0, 1}), param, {1})); + + // Create a reshape with zero sized result and without layout. + Shape reshaped_shape = ShapeUtil::MakeShape(F32, {0}); + reshaped_shape.clear_layout(); + builder.AddInstruction( + HloInstruction::CreateReshape(reshaped_shape, broadcast)); + + std::unique_ptr module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) { + Shape shape = ShapeUtil::MakeShape(F32, {}); + shape.clear_layout(); + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + + HloInstruction* const_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(20.0f))); + builder.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kDivide, + param, const_value)); + + std::unique_ptr module = CreateNewVerifiedModule(); + 
module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Multiply())); +} + +// Test that 1/sqrt(X) is simplified to rsqrt(X). +TEST_F(AlgebraicSimplifierTest, RecipSqrt) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + sqrt = f32[] sqrt(p0) + ROOT div = f32[] divide(p1, sqrt) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder(m::Parameter(1), + m::Rsqrt(m::Parameter(0))))); +} + +// Test that 1/rsqrt(X) is simplified to sqrt(X). +TEST_F(AlgebraicSimplifierTest, RecipRsqrt) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + rsqrt = f32[] rsqrt(p0) + ROOT div = f32[] divide(p1, rsqrt) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder(m::Parameter(1), + m::Sqrt(m::Parameter(0))))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index ef5e211646e7b0b66b8e6c09948be58063422943..6cb0e985e57016e5a22fba50c3e3ad6970f1b178 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -142,13 +142,13 @@ StatusOr> AllocationTracker::DeconstructTuple( // We only need to care about replica id 0 here, since the GlobalDataHandle is // the same 
for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; - if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { + if (!shaped_buffer->on_host_shape().IsTuple()) { return InvalidArgument("global data handle %d is not a tuple", data.handle()); } // If the on-host representation is a tuple, then the on-device one should be // as well. - TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); + TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple()); if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("Deconstructing nested tuples is not implemented."); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 362bc44a1cf377b51c5519c6ab5e0d9628e80e58..52d6982c70f7962ea9f54db0a4b1f2089a122c1c 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -26,38 +26,72 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace { - namespace m = match; -// If the argument instruction is a CRS in the sequence -// AR -> Convert -> Add -> CRS -// then return the AR in the sequence. -// TODO(b/117554291): Rewrite this to recognize more general patterns, -// not just the specific one of AR -> Add -> Convert -> CRS. -absl::optional MatchesArCrsPattern( +// Checks if the argument instruction is an AllReduce, followed by a certain +// sequence of instructions and then a CRS. It must be possible to move +// the AR past each instruction in the sequence. 
Returns the CRS, which is the +// last instruction in the sequence. +absl::optional ArCrsCombiner::MatchesArCrsPattern( HloInstruction* instruction) { - HloInstruction *ar, *convert, *add, *crs; - if (Match(instruction, - m::CrossReplicaSum( - &crs, m::Add(&add, m::Op(), - m::Convert(&convert, - m::CrossReplicaSum(&ar, m::Op()))))) && - ar->users().size() == 1 && ar->shape().element_type() == BF16 && - convert->shape().element_type() == F32 && !crs->all_reduce_id()) { - return ar; + auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { + if (instruction->user_count() != 1) { + return false; + } + switch (instruction->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + return true; + case HloOpcode::kConvert: + // Can be moved across if both input and output is either float or + // integer (e.g. S32<->U32 or F32<->BF16) + return ShapeUtil::ElementIsFloating(instruction->shape()) == + ShapeUtil::ElementIsFloating(instruction->operand(0)->shape()); + case HloOpcode::kAdd: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: + // Only supported for floating point operands. 
+ return ShapeUtil::ElementIsFloating(instruction->shape()); + default: + return false; + } + }; + + auto computation_is_addition = [](HloComputation* c) { + return c->instruction_count() == 3 && + Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); + }; + + if (!instruction->IsCrossModuleAllReduce() || + !computation_is_addition(instruction->called_computations()[0]) || + instruction->user_count() != 1) { + return absl::nullopt; + } + auto next = instruction->users()[0]; + int64 distance = 1; + while (!next->IsCrossReplicaAllReduce()) { + if (can_ar_move_past_instruction(next)) { + next = next->users()[0]; + } else { + return absl::nullopt; + } + ++distance; + } + if (!Cast(next)->IsNoop() && + computation_is_addition(next->called_computations()[0])) { + return absl::optional(ArCrsPair(instruction, next, distance)); + } else { + return absl::nullopt; } - return absl::optional(); } -} // namespace - absl::optional ArCrsCombiner::WhileFromBodyParameter( HloInstruction* instruction) { CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); @@ -69,7 +103,7 @@ absl::optional ArCrsCombiner::WhileFromBodyParameter( return caller_instruction; } } - return absl::optional(); + return absl::nullopt; } std::vector ArCrsCombiner::GetAllTuples( @@ -160,6 +194,15 @@ bool ArCrsCombiner::InstructionsComputeSameValue( if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { return false; } + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + if (i1->IsCrossModuleAllReduce()) { + return i1->Identical(*i2, + /*eq_operands=*/std::equal_to(), + eq_computations, + /*layout_sensitive=*/false); + } visited_pairs->emplace(min_uid, max_uid); for (int i = 0; i < operands1.size(); ++i) { auto operand1 = operands1[i]; @@ -185,19 +228,61 @@ bool ArCrsCombiner::InstructionsComputeSameValue( // InstructionsComputeSameValue earlier. 
auto eq_instructions = [](const HloInstruction* i1, const HloInstruction* i2) -> bool { return true; }; - auto eq_computations = [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }; return i1->Identical(*i2, eq_instructions, eq_computations, /*layout_sensitive=*/false); } void ArCrsCombiner::GroupAllReducesById(HloModule* module) { + // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS), + // ... , (ARn, CRS). + // If as we traverse the HLO graph we start tracking the pair (AR2, CRS), + // and later find that AR1's distance from the CRS is longer, we discard + // AR2 and start tracking AR1. We put the discarded ids in this set, in order + // to skip processing of short paths when we encounter the other ARs that + // have the same id as AR2. + absl::flat_hash_set discarded_ar_ids; for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - auto ar = MatchesArCrsPattern(instruction); - if (ar) { - all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar); + auto maybe_pair = MatchesArCrsPattern(instruction); + if (maybe_pair) { + auto pair = *maybe_pair; + int64 ar_id = *(instruction->all_reduce_id()); + if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) { + continue; + } + auto it = crs_reserved_map_.find(pair.crs); + if (it != crs_reserved_map_.end()) { + auto prev_ar_id = it->second; + // Since there is another AR paired with CRS, + // all_reduce_map_[prev_ar_id] should exist, but + // all_reduce_map_[ar_id] shouldn't. + CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end()); + CHECK_NE(prev_ar_id, ar_id); + auto prev_pair = all_reduce_map_[prev_ar_id].back(); + int64 prev_distance = prev_pair.distance; + if (prev_distance < pair.distance) { + // The current AR's distance to CRS is longer than the previously + // tracked AR, so we discard the previous AR. 
+ all_reduce_map_.erase(prev_ar_id); + discarded_ar_ids.insert(prev_ar_id); + all_reduce_map_[ar_id].push_back(pair); + crs_reserved_map_[pair.crs] = ar_id; + } else { + // Discard the current AR id because we are keeping the previously + // tracked AR. + discarded_ar_ids.insert(ar_id); + } + } else { + if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) { + int64 prev_distance = all_reduce_map_[ar_id].back().distance; + CHECK_EQ(prev_distance, pair.distance) + << "All ARs with the same AR ID must have the same distance " + "from the corresponding CRSs. Found: " + << prev_distance << " and " << pair.distance; + } + all_reduce_map_[ar_id].push_back(pair); + crs_reserved_map_[pair.crs] = ar_id; + } } } } @@ -205,20 +290,25 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { - auto instruction_vec = it.second; - CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); - - auto instr_0 = instruction_vec[0]; - auto add_0 = instr_0->users()[0]->users()[0]; - CHECK_EQ(HloOpcode::kAdd, add_0->opcode()); - - for (int i = 1; i < instruction_vec.size(); ++i) { - auto instr_i = instruction_vec[i]; - auto add_i = instr_i->users()[0]->users()[0]; - CHECK_EQ(HloOpcode::kAdd, add_i->opcode()); + auto all_reduce_id = it.first; + auto pairs_vec = it.second; + CHECK_EQ(pairs_vec.size(), num_spatial_partitions_); + auto instr_0 = pairs_vec[0].ar; + for (int i = 1; i < pairs_vec.size(); ++i) { + auto instr_i = pairs_vec[i].ar; + auto next_0 = instr_0->users()[0]; + auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; - if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { - all_reduce_map_.erase(it.first); + while (true) { + if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { + all_reduce_map_.erase(all_reduce_id); + break; + } + if (next_0->IsCrossReplicaAllReduce()) { + break; + } + next_0 = next_0->users()[0]; + next_i = 
next_i->users()[0]; } } } @@ -228,47 +318,59 @@ StatusOr ArCrsCombiner::RewriteGraph() { if (all_reduce_map_.empty()) { return false; } - - auto computation_is_addition = [](HloComputation* c) { - return c->instruction_count() == 3 && - Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); - }; - for (auto it : all_reduce_map_) { - auto instruction_vec = it.second; - for (auto all_reduce : instruction_vec) { + auto pairs_vec = it.second; + for (auto pair : pairs_vec) { + auto all_reduce = pair.ar; auto parent_computation = all_reduce->parent(); - auto convert = all_reduce->users()[0]; - auto add = convert->users()[0]; - auto crs = add->users()[0]; - - if (!computation_is_addition(all_reduce->called_computations()[0]) || - !computation_is_addition(crs->called_computations()[0])) { - continue; + auto all_reduce_id = all_reduce->all_reduce_id(); + auto prev = all_reduce->mutable_operand(0); + auto next = all_reduce->users()[0]; + TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev)); + TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + while (!next->IsCrossReplicaAllReduce()) { + switch (next->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + case HloOpcode::kConvert: + case HloOpcode::kMultiply: + break; + case HloOpcode::kAdd: + case HloOpcode::kSubtract: { + auto other_operand = (next->operands()[0] == prev) + ? next->operands()[1] + : next->operands()[0]; + // To move the AR past the addition/subtraction, we need to divide + // other_operand by the number of spatial partitions, except if + // other_operand is a cross-module AR, which can be eliminated. 
+ if (other_operand->IsCrossModuleAllReduce() && + other_operand->user_count() == 1) { + TF_CHECK_OK(other_operand->ReplaceAllUsesWith( + other_operand->mutable_operand(0))); + } else { + auto shape = other_operand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = parent_computation->AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDivide, + other_operand, divisor)); + TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + } + break; + } + default: + LOG(FATAL) << "Unexpected instruction: " << next->ToShortString(); + } + prev = next; + next = next->users()[0]; } - HloInstruction* other_summand = (add->operands()[0] == convert) - ? add->operands()[1] - : add->operands()[0]; - // To move the AR past the addition, we need to divide other_summand by - // the number of spatial partitions. - CHECK_EQ(all_reduce->user_count(), 1); - TF_CHECK_OK( - all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); - auto shape = other_summand->shape(); - Literal lit(shape); - lit.PopulateWithValue(num_spatial_partitions_); - auto divisor = parent_computation->AddInstruction( - HloInstruction::CreateConstant(lit.Clone())); - auto division = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDivide, other_summand, divisor)); - TF_CHECK_OK(other_summand->ReplaceUseWith(add, division)); // The AllReduce and the CRS are combined to an all-core AllReduce. 
- crs->set_all_reduce_id(all_reduce->all_reduce_id()); - TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + next->set_all_reduce_id(all_reduce_id); } } - return true; } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index f6a7ef76ec3b76972d1b2c7fb548cecfb9423160..f503e1d5f2b519687e40818a61f0c0be9dfd3ab0 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -25,9 +25,48 @@ limitations under the License. namespace xla { -// Combine an AllReduce and a CrossReplicaSum when they are close to each other -// in the graph, to use an efficient CrossReplicaSum implementation that -// fully utilizes the interconnect bandwidth. +// When the HLO graph contains a cross-module AllReduce, followed by some simple +// linear operations, followed by a cross-replica AllReduce (also known as +// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an +// efficient AllReduce implementation that fully utilizes the interconnect +// bandwidth. +// Such sequences appear in spatially partitioned models. +// This pass must run right after spatial partitioning, when the code is still +// in a single HLO module. +// +// The steps are: +// 1) Find CMARs followed by simple ops followed by CRARs. +// 2) Group CMARs by all_reduce_id. They must all be rewritten. +// 3) Prove that the CMAR patterns in each core produce the same result. +// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the +// other operand by the number of spatial partitions. +// 5) Turn the CRAR into an all-core AllReduce. +// +// The pass also handles the case where multiple CMARs lead to the same CRAR, +// and eliminates all CMARs. 
This graph: +// +// Y +// | +// X CMAR_2 Z +// | \ / +// CMAR_1 + +// \ / +// + +// | +// CRAR +// +// gets rewritten to: +// +// Z num_partitions +// \ / +// Y div +// \ / +// X + +// \ / +// + +// | +// all-core AR +// class ArCrsCombiner : public HloModulePass { public: ArCrsCombiner(int num_spatial_partitions) @@ -40,6 +79,28 @@ class ArCrsCombiner : public HloModulePass { HloInstruction* i2); private: + // We used this struct because multiple ARs could be paired with the same CRS. + // In this case, we want to select the AR that is furthest from the CRS, + // because it makes it easier to eliminate all ARs during RewriteGraph. + struct ArCrsPair { + HloInstruction* ar; + HloInstruction* crs; + // The length of the path from AR to CRS in the HLO graph. + int64 distance; + + ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum, + int64 dist) + : ar(all_reduce), crs(cross_replica_sum), distance(dist) {} + + string ToString() { + return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(), + ", distance: ", distance, ")"); + } + }; + + absl::optional MatchesArCrsPattern( + HloInstruction* instruction); + // If the passed instruction is a while parameter, and the while body is only // called by a single while instruction, return the while instruction. absl::optional WhileFromBodyParameter( @@ -77,8 +138,13 @@ class ArCrsCombiner : public HloModulePass { int num_spatial_partitions_; - // Map from all-reduce ids to the all reduce instructions. - absl::flat_hash_map> all_reduce_map_; + // Map from all-reduce ids to the AR/CRS pairs. + absl::flat_hash_map> all_reduce_map_; + + // Map from a CRS instruction to the all-reduce ID of the AR paired with the + // CRS. Sometimes, several ARs in the code could be paired with the same CRS. + // We use this map to pick a single AR/CRS path to rewrite. 
+ absl::flat_hash_map crs_reserved_map_; std::unique_ptr call_graph_; }; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 10171835d83c75fef091a34b8fe102d263211307..9c9db74fd2fdab836f91d2f749d08ad93f8879b0 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -32,8 +32,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) } )"; @@ -91,7 +91,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple1 = (f32[2,2]) tuple(%constant.f32) %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2) @@ -152,7 +152,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 @@ -174,7 +174,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], 
f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -196,8 +196,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -226,7 +226,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -235,7 +235,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -263,7 +263,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ 
-272,8 +272,8 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -301,8 +301,8 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) @@ -311,7 +311,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -326,11 +326,27 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } -TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { +void CompareReplicaGroups(const std::vector& groups_before, + const std::vector& groups_after) { + ASSERT_EQ(groups_before.size(), groups_after.size()); + for (int i = 0; i < groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. 
+ auto group_before = groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -342,49 +358,258 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%convert.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + 
module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Convert(op::Parameter())), + op::AllReduce(op::Convert(op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] { + %a = f32[2,1] parameter(0) + %b = f32[2,1] parameter(1) + ROOT %add = f32[2,1] add(%a, %b) +} + +%sum.2 (x: f32[2], y: f32[2]) -> f32[2] { + %x = f32[2] parameter(0) + %y = f32[2] parameter(1) + ROOT %add = f32[2] add(%x, %y) +} - %cross-replica-sum.ar.1 = bf16[2,2] - cross-replica-sum(%constant.bf16), +ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { + %p = f32[2,1] parameter(0) + + %all-reduce.ar.1 = f32[2,1] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=0} + %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1) + %all-reduce.1 = f32[2] + all-reduce(%bitcast.1), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[2,1] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=1} + %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2) + %all-reduce.2 = f32[2] + all-reduce(%bitcast.2), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Bitcast(op::Parameter())), + op::AllReduce(op::Bitcast(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.f32, sharding={maximal device=0} - %convert.1 = f32[2,2] - convert(%cross-replica-sum.ar.1), + %multiply.1 = f32[] + multiply(%all-reduce.ar.1, %constant.f32), sharding={maximal device=0} - %add.1 = f32[2,2] + %all-reduce.1 = f32[] + all-reduce(%multiply.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=1} + %multiply.2 = f32[] + multiply(%all-reduce.ar.2, %constant.f32), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%multiply.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + 
sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant())), + op::AllReduce(op::Multiply(op::Parameter(), op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32 = f32[] constant(2) + + %all-reduce.ar.1 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %add.1 = f32[] add(%constant.f32, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] - cross-replica-sum(%add.1), + %all-reduce.1 = f32[] + all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), 
replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] - convert(%cross-replica-sum.ar.2), + %convert.2 = f32[] + convert(%all-reduce.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] - cross-replica-sum(%add.2), + %all-reduce.2 = f32[] + all-reduce(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) - tuple(%cross-replica-sum.1, %cross-replica-sum.2), + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), sharding={{maximal device=0}, {maximal device=1}} } )"; @@ -400,32 +625,21 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { EXPECT_THAT( module->entry_computation()->root_instruction(), op::Tuple( - op::CrossReplicaSum(op::Add( - op::Divide(op::Constant(), op::Constant()), op::Convert())), - op::CrossReplicaSum(op::Add( - op::Divide(op::Constant(), op::Constant()), op::Convert())))); + op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()), + op::Convert())), + op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()), + op::Convert())))); auto crs_after = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); - ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); - for (int i = 0; i < replica_groups_before.size(); ++i) { - // Somewhat verbose way to compare the replica_ids, because EqualsProto - // is not available in the open-source build. 
- auto group_before = replica_groups_before[i]; - std::vector ids_before(group_before.replica_ids().begin(), - group_before.replica_ids().end()); - auto group_after = replica_groups_after[i]; - std::vector ids_after(group_after.replica_ids().begin(), - group_after.replica_ids().end()); - EXPECT_EQ(ids_before, ids_after); - } + CompareReplicaGroups(replica_groups_before, replica_groups_after); } TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -437,50 +651,517 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32.1 = f32[] constant(2) + %constant.f32.2 = f32[] constant(3) - %cross-replica-sum.ar.1 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.1 = bf16[] + all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=0} - %convert.1 = f32[2,2] - convert(%cross-replica-sum.ar.1), + %convert.1 = f32[] + convert(%all-reduce.ar.1), sharding={maximal device=0} - %add.1 = f32[2,2] + %add.1 = f32[] add(%constant.f32.1, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] - cross-replica-sum(%add.1), + %all-reduce.1 = f32[] + all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.2 = bf16[] + 
all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] - convert(%cross-replica-sum.ar.2), + %convert.2 = f32[] + convert(%all-reduce.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32.2, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] - cross-replica-sum(%add.2), + %all-reduce.2 = f32[] + all-reduce(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +TEST_F(ArCrsCombinerTest, ArThenCrsDontCrash) { + const char* module_str = R"( +HloModule foobar + +%sum.1 (a: f32[], b: f32[]) -> f32[] { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + ROOT %add = f32[] add(%a, %b) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%all-reduce.ar.1), + replica_groups={{0,1}}, + to_apply=%sum.1, + sharding={maximal device=0} + %multiply.1 = f32[] + multiply(%all-reduce.1, %constant.f32), + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%all-reduce.ar.2), + replica_groups={{0,1}}, + to_apply=%sum.1, + sharding={maximal device=1} + %multiply.2 = f32[] + multiply(%all-reduce.2, %constant.f32), + sharding={maximal device=1} + + 
ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Parameter()), + op::AllReduce(op::Parameter()))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleAdds) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.1 = f32[] constant(1) + %constant.2 = f32[] constant(2) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add.11 = f32[] + add(%constant.1, %all-reduce.ar.1), + sharding={maximal device=0} + %add.12 = f32[] + add(%constant.2, %add.11), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%add.12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add.21 = f32[] + add(%constant.1, %all-reduce.ar.2), + sharding={maximal device=0} + %add.22 = f32[] + add(%constant.2, %add.21), + sharding={maximal device=0} + %all-reduce.2 = f32[] + all-reduce(%add.22), + 
replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))), + op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=0} + %sub.1 = f32[] + subtract(%constant.f32, %all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%sub.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=1} + 
%sub.2 = f32[] + subtract(%constant.f32, %all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%sub.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) - tuple(%cross-replica-sum.1, %cross-replica-sum.2), + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()), + op::Parameter())), + op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()), + op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add11 = f32[] + add(%ar11, %const1), + sharding={maximal device=0} + %ar12 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=0} + %add12 = f32[] + add(%add11, 
%ar12), + sharding={maximal device=0} + %crs1 = f32[] + all-reduce(%add12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %ar21 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=1} + %add21 = f32[] + add(%ar21, %const1), + sharding={maximal device=1} + %ar22 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=1} + %add22 = f32[] + add(%add21, %ar22), + sharding={maximal device=1} + %crs2 = f32[] + all-reduce(%add22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%crs1, %crs2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())), + op::Parameter())), + op::AllReduce(op::Add( + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())), + op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + 
%const2 = f32[] constant(2) + + %ar11 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %ar12 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=0} + %add11 = f32[] + add(%ar12, %const1), + sharding={maximal device=0} + %add12 = f32[] + add(%ar11, %add11), + sharding={maximal device=0} + %crs1 = f32[] + all-reduce(%add12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %ar21 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=1} + %ar22 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=1} + %add21 = f32[] + add(%ar22, %const1), + sharding={maximal device=1} + %add22 = f32[] + add(%ar21, %add21), + sharding={maximal device=1} + %crs2 = f32[] + all-reduce(%add22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%crs1, %crs2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))), + op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))))); + + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + 
CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%convert.2), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), sharding={{maximal device=0}, {maximal device=1}} } )"; diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 5c180cbdd492031e133b81149f0f4698619b7788..d016d3e03d5e994841b81cda6214b6ff7cb550be 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/byte_order.h" @@ -57,18 +56,48 @@ int BackendOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +BackendOptions& BackendOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + return *this; +} + +const absl::optional>& BackendOptions::allowed_devices() const { + return allowed_devices_; +} + +namespace { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool) + : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + tensorflow::thread::ThreadPool* pool_ = nullptr; +}; + +} // namespace + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. 
-struct Backend::EigenThreadPoolWrapper { - explicit EigenThreadPoolWrapper(const int num_threads) +struct Backend::IntraOpThreadPool { + explicit IntraOpThreadPool(const int num_threads) : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(), "XLAEigen", num_threads)), - wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), + wrapper(new EigenThreadPoolWrapper(pool.get())), device(new Eigen::ThreadPoolDevice(wrapper.get(), wrapper->NumThreads())) {} std::unique_ptr pool; - std::unique_ptr wrapper; + std::unique_ptr wrapper; std::unique_ptr device; }; @@ -76,8 +105,9 @@ struct Backend::EigenThreadPoolWrapper { const BackendOptions& options) { se::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); - TF_ASSIGN_OR_RETURN(auto stream_executors, - PlatformUtil::GetStreamExecutors(platform)); + TF_ASSIGN_OR_RETURN( + auto stream_executors, + PlatformUtil::GetStreamExecutors(platform, options.allowed_devices())); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto computation_placer, @@ -104,12 +134,10 @@ StatusOr Backend::BorrowStream(int device_ordinal) { StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(mu_); - if (0 == stream_pools_.count(executor)) { - stream_pools_.emplace(std::piecewise_construct, - std::forward_as_tuple(executor), - std::forward_as_tuple()); + if (!stream_pools_.contains(executor)) { + stream_pools_.emplace(executor, absl::make_unique()); } - return stream_pools_.at(executor).BorrowStream(executor); + return stream_pools_.at(executor)->BorrowStream(executor); } Backend::Backend(se::Platform* platform, Compiler* compiler, @@ -137,8 +165,7 @@ Backend::Backend(se::Platform* platform, Compiler* compiler, const int num_threads = intra_op_parallelism_threads > 0 ? 
intra_op_parallelism_threads : tensorflow::port::NumSchedulableCPUs(); - intra_op_thread_pool_wrapper_.reset( - new EigenThreadPoolWrapper(num_threads)); + intra_op_thread_pool_.reset(new IntraOpThreadPool(num_threads)); } } @@ -150,17 +177,17 @@ int Backend::default_device_ordinal() const { const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { - if (intra_op_thread_pool_wrapper_ == nullptr) { + if (intra_op_thread_pool_ == nullptr) { return nullptr; } - return intra_op_thread_pool_wrapper_->device.get(); + return intra_op_thread_pool_->device.get(); } tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { - if (intra_op_thread_pool_wrapper_ == nullptr) { + if (intra_op_thread_pool_ == nullptr) { return nullptr; } - return intra_op_thread_pool_wrapper_->pool.get(); + return intra_op_thread_pool_->pool.get(); } StatusOr Backend::stream_executor( diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index a2dafbe803f8bd5f23e4e9f3f6d3e6f744c9fab9..e7f29a044b95015aa7e547373c24971646833280 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -18,9 +18,11 @@ limitations under the License. #include #include +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -53,9 +55,16 @@ class BackendOptions { BackendOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices for selectively constructing stream executors + // on the platform. 
+ BackendOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_ = nullptr; int intra_op_parallelism_threads_ = -1; + absl::optional> allowed_devices_; }; // Class which encapsulates an XLA backend. It includes everything necessary @@ -147,7 +156,6 @@ class Backend { Status ResetDevices(); private: - struct EigenThreadPoolWrapper; Backend(se::Platform* platform, Compiler* compiler, absl::Span stream_executors, TransferManager* transfer_manager, @@ -167,13 +175,15 @@ class Backend { tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map stream_pools_ GUARDED_BY(mu_); + absl::flat_hash_map> + stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; // For the CPU backend, an Eigen threadpool device for use by Eigen code. - std::unique_ptr intra_op_thread_pool_wrapper_; + struct IntraOpThreadPool; + std::unique_ptr intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index eda026ac5685dc469a6230094eb28b3618e36400..dbabd82dd55465dd4c85a56aea849a3e3702d6bf 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -28,6 +28,13 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( *rhs = batch_dot->mutable_operand(1); const Shape& lhs_shape = lhs->shape(); + // A dot with no contracting dims will be rewritten into a multiply by + // AlgebraicSimplifier. Dots with multiple contracting dims are currently + // unsupported. 
+ if (dim_numbers.lhs_contracting_dimensions_size() != 1) { + return false; + } + std::vector degenerate_dims; for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { if (lhs_shape.dimensions(batch_dim) == 1) { diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 52ec1a794c5e9f4452a4bf2b648f453d8acfe976..a81f394a38f091b89b7f1e4d26653ff549f35b75 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -169,5 +169,47 @@ main { /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); } +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsNonContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,101] parameter(0) + b = f32[1,101] parameter(1) + ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0}, + lhs_contracting_dims={}, + rhs_batch_dims={0}, + rhs_contracting_dims={} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsMultipleContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + lhs = f32[1,5,17,10,13] parameter(0) + rhs = f32[1,9,10,13,6,5] parameter(1) + ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0}, + rhs_batch_dims={2,0}, + lhs_contracting_dims={1,4}, + rhs_contracting_dims={5,3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 
0e6ca1871b379a2f55b92207133822fc6258b007..620876c264ad446542e3ad8229593c1f56c94604 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -95,15 +95,8 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, const std::function)>& add_instruction) { - HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( - operand->shape(), - add_instruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(operand->shape().element_type(), {}), - add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(-0.5f))))), - {})); - return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, - operand, exponent); + return HloInstruction::CreateUnary(operand->shape(), HloOpcode::kRsqrt, + operand); } std::unique_ptr Mean( @@ -123,7 +116,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { auto elements_per_feature_u32 = add_instruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); - for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + for (int64 i = 0; i < operand->shape().rank(); ++i) { if (i == feature_index) { continue; } @@ -229,7 +222,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } @@ -357,7 +350,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } @@ -494,7 +487,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( std::vector 
dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) { + for (int64 i = 0; i < activation_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index e9d30fc03c1c3194de577e6683b36a95641694d9..e62d72b323bd1d113e9d87bf8602bfb434c40d61 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -34,8 +34,8 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum which can have a tuple output. - Status HandleCrossReplicaSum(HloInstruction* crs) override; + // Special handling for all-reduce which can have a tuple output. + Status HandleAllReduce(HloInstruction* crs) override; static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { @@ -176,8 +176,7 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { return TryFoldBF16Conversions(hlo); } -Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { +Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { if (crs->IsCrossModuleAllReduce()) { // Cross-module all-reduce has side effect. return Status::OK(); @@ -191,7 +190,7 @@ Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( } // If the output is not a tuple, we don't need special handling. 
- if (!ShapeUtil::IsTuple(crs->shape())) { + if (!crs->shape().IsTuple()) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 4ce351acc2c359773e618da70360c96faf5ca379..2232a2cbdfe0cf64dc4fb10d4598c0ad8b51ee5e 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -38,7 +38,7 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -49,7 +49,7 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -58,7 +58,7 @@ class TestBFloat16Support : public BFloat16Support { bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -213,7 +213,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { EXPECT_EQ(tuple->operand(1), convert0); } -TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { +TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) { auto builder = HloComputation::Builder(TestName()); auto module = CreateNewVerifiedModule(); @@ -236,11 +236,10 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* b = 
builder.AddInstruction( HloInstruction::CreateParameter(1, f32_shape, "b")); - HloInstruction* crs = - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_groups=*/{}, /*barrier=*/"", - /*all_reduce_id=*/absl::nullopt)); + HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum, + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index b8a8f844eff17a95d4073f53495e0027c481f558..d1b14d604f0559b6b18f7d1fba127669c241c8a3 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -362,8 +362,8 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { } // TODO(b/112040122): Correctly normalize variadic reduce. if ((hlo->opcode() == HloOpcode::kSort || - hlo->opcode() == HloOpcode::kCrossReplicaSum) && - ShapeUtil::IsTuple(hlo->shape())) { + hlo->opcode() == HloOpcode::kAllReduce) && + hlo->shape().IsTuple()) { return HandleMultipleOutputs(hlo); } return HandleInstruction(hlo); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 9f97d18c565c7915b9f9346f0c6330cdc3c707e9..2caa979745b3b40817acb1b6951e1de5ffa294a4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/bfloat16_support.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -232,7 +233,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32); } -TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) { auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("sum"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( @@ -253,11 +254,10 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateParameter(1, bf16_shape, "b")); - HloInstruction* crs = - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_groups=*/{}, /*barrier=*/"", - /*all_reduce_id=*/absl::nullopt)); + HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); @@ -283,8 +283,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { HloInstruction* value = builder.AddInstruction( HloInstruction::CreateParameter(1, s32_shape, "value")); - HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, 
key, {value})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); @@ -309,8 +312,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { HloInstruction* value = builder.AddInstruction( HloInstruction::CreateParameter(1, bf16_shape, "value")); - HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), 0, key, {value})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 63d4572f2028c462df1cac9d5e4ee616e407f37b..bab63f66d83b712d756078bef84926eed235f6b5 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -276,8 +276,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( *use.instruction, use.operand_number)) { if (use.instruction->opcode() == HloOpcode::kTuple || - (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && - ShapeUtil::IsTuple(use.instruction->shape()))) { + (use.instruction->opcode() == HloOpcode::kAllReduce && + use.instruction->shape().IsTuple())) { ShapeIndex use_output_index{use.operand_number}; for (int64 i : use.operand_index) { use_output_index.push_back(i); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 
5be7141aae423adb4fe2f39262e463ff25ae8234..a9b5d9916e400b39039248098c22a715e44ccfd2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -209,7 +209,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")))); auto reduction = module->AddEmbeddedComputation(rb.Build()); HloInstruction* all_reduce = - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction, /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1)); HloInstruction* gte0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8d7c62447852fd946440c41389300a92377c471f..cbebbdc8a2d7d0b65f12accbe424bea383ff5355 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -86,10 +86,9 @@ std::vector ColorInterferenceGraph( // first, but it would be good to investigate other ordering heuristics too. 
std::vector nodes(node_count); std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); + absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); const int64 kColorUnassigned = -1; std::vector assigned_colors(node_count, kColorUnassigned); @@ -138,8 +137,8 @@ Status GatherComputationsByAllocationType( worklist.pop_front(); const HloComputation* computation = worklist_front.first; bool is_thread_local = worklist_front.second; - bool in_thread_local_set = thread_local_set.count(computation) > 0; - bool in_global_set = global_set.count(computation) > 0; + bool in_thread_local_set = thread_local_set.contains(computation); + bool in_global_set = global_set.contains(computation); // If the computation has already been added to the respective set, then // nothing to do. @@ -186,12 +185,13 @@ Status GatherComputationsByAllocationType( worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. break; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. worklist.push_back(std::make_pair(subcomputation, @@ -207,9 +207,9 @@ Status GatherComputationsByAllocationType( // Add the computations to the vectors in post order. 
for (auto* computation : module->MakeComputationPostOrder()) { - if (thread_local_set.count(computation) > 0) { + if (thread_local_set.contains(computation)) { thread_local_computations->push_back(computation); - } else if (global_set.count(computation) > 0) { + } else if (global_set.contains(computation)) { global_computations->push_back(computation); } // If the computation is not reachable from the entry computation, then it @@ -219,13 +219,6 @@ Status GatherComputationsByAllocationType( return Status::OK(); } -size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { - uint64 h = std::hash()(s.index()); - h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); - h = tensorflow::Hash64Combine(h, std::hash()(s.size())); - return h; -} - string BufferAllocation::Slice::ToString() const { return absl::StrCat("{index:", index(), ", offset:", offset_, ", size:", size_, "}"); @@ -240,7 +233,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); - CHECK(assigned_buffers_.count(&buffer) == 0) + CHECK(!assigned_buffers_.contains(&buffer)) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; CHECK_LE(offset, size_) << "LogicalBuffer " << buffer @@ -279,11 +272,12 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); } - std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), - [](const BufferAllocationProto::Assigned& assign1, - const BufferAllocationProto::Assigned& assign2) { - return assign1.logical_buffer_id() < assign2.logical_buffer_id(); - }); + absl::c_sort(*proto.mutable_assigned(), + [](const BufferAllocationProto::Assigned& assign1, + const BufferAllocationProto::Assigned& assign2) { + return 
assign1.logical_buffer_id() < + assign2.logical_buffer_id(); + }); return proto; } @@ -315,10 +309,10 @@ string BufferAllocation::ToString() const { for (const auto& buffer_offset_size : assigned_buffers_) { sorted_buffers.push_back(buffer_offset_size.first); } - std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); + absl::c_sort(sorted_buffers, + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); StrAppend(&output, absl::StrFormat( @@ -346,7 +340,7 @@ const PointsToSet& BufferAssignment::GetPointsToSet( bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const { TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); - return allocation_index_for_buffer_.count(&buffer) > 0; + return allocation_index_for_buffer_.contains(&buffer); } const BufferAllocation& BufferAssignment::GetAssignedAllocation( @@ -401,7 +395,7 @@ bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction, const ShapeIndex& index) const { for (const LogicalBuffer* buffer : GetPointsToSet(instruction).element(index)) { - if (allocation_index_for_buffer_.count(buffer) > 0) { + if (allocation_index_for_buffer_.contains(buffer)) { return true; } } @@ -459,8 +453,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { - using SliceSet = - flat_hash_set; + using SliceSet = flat_hash_set; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. 
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -487,10 +480,9 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, // didn't return the empty set) for both HLOs, and the two resulting sets of // slices are disjoint. return !slices_a.empty() && !slices_b.empty() && - std::none_of(slices_a.begin(), slices_a.end(), - [&](const BufferAllocation::Slice& slice) { - return slices_b.count(slice) > 0; - }); + absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) { + return slices_b.contains(slice); + }); } StatusOr @@ -519,7 +511,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, void BufferAssignment::AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, int64 offset, int64 size) { - CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer)) + CHECK(!allocation_index_for_buffer_.contains(&buffer)) << "LogicalBuffer " << buffer << " already has an allocation."; CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty()) << "Non-reusable allocation already assigned a buffer: " @@ -761,7 +753,8 @@ namespace { bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment, const LogicalBuffer& a_buffer, const LogicalBuffer& b_buffer) { - auto call_graph = assignment->liveness().hlo_ordering().call_graph(); + const CallGraph& call_graph = + assignment->liveness().hlo_ordering().call_graph(); const HloInstruction* a_ancestor; const HloInstruction* b_ancestor; std::tie(a_ancestor, b_ancestor) = @@ -960,35 +953,35 @@ Status BufferAssigner::AssignBuffersForComputation( // operands (assuming operands are the same/larger size) enabling the // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. 
- std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [has_sequential_order, &liveness, &post_order_position, assignment]( - const LogicalBuffer* a, const LogicalBuffer* b) { - // Primary sort is by decreasing buffer size. - const int64 a_size = assignment->buffer_size_(*a); - const int64 b_size = assignment->buffer_size_(*b); - if (a_size != b_size) { - return a_size > b_size; // use ">" for decreasing size. - } - // Otherwise live out buffers come before others, if the - // instructions are sequentially ordered. - if (has_sequential_order) { - const bool a_live_out = liveness.MaybeLiveOut(*a); - const bool b_live_out = liveness.MaybeLiveOut(*b); - if (a_live_out != b_live_out) { - return a_live_out; - } - } - // Final tiebreaker is in instruction post order. - return post_order_position.at(a->instruction()) < - post_order_position.at(b->instruction()); - }); + absl::c_sort(sorted_buffers, + [has_sequential_order, &liveness, &post_order_position, + assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + // Primary sort is by decreasing buffer size. + const int64 a_size = assignment->buffer_size_(*a); + const int64 b_size = assignment->buffer_size_(*b); + if (a_size != b_size) { + return a_size > b_size; // use ">" for decreasing size. + } + // Otherwise live out buffers come before others, if the + // instructions are sequentially ordered. + if (has_sequential_order) { + const bool a_live_out = liveness.MaybeLiveOut(*a); + const bool b_live_out = liveness.MaybeLiveOut(*b); + if (a_live_out != b_live_out) { + return a_live_out; + } + } + // Final tiebreaker is in instruction post order. + return post_order_position.at(a->instruction()) < + post_order_position.at(b->instruction()); + }); // BufferAllocations are necessarily created in decreasing size order. Keep // indices of previously created BufferAllocations in allocation_indices. 
std::vector allocation_indices; for (const LogicalBuffer* buffer : sorted_buffers) { VLOG(3) << "Assigning allocation to: " << *buffer; - if (colocated_buffers.count(buffer) > 0) { + if (colocated_buffers.contains(buffer)) { // Colocated buffers are currently assigned in an earlier pass. VLOG(3) << "Skipping colocated buffer: " << *buffer; continue; @@ -1020,10 +1013,14 @@ Status BufferAssigner::AssignBuffersForComputation( // callers. BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size); + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), buffer->index()); allocation->set_entry_computation_parameter( - instruction->parameter_number(), buffer->index()); - VLOG(3) << "New allocation #" << allocation->index() - << " for entry computation parameter: " << *buffer; + instruction->parameter_number(), buffer->index(), + parameter_has_alias); + VLOG(3) << "Mark allocation #" << allocation->index() + << " as entry computation parameter: " << *buffer; continue; } @@ -1036,7 +1033,7 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - if (ShapeUtil::IsTuple(buffer->shape())) { + if (buffer->shape().IsTuple()) { BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size); allocation->set_is_tuple(true); @@ -1056,7 +1053,7 @@ Status BufferAssigner::AssignBuffersForComputation( assignment->GetAllSlices(operand, /*index=*/{})) { BufferAllocation* allocation = assignment->GetMutableAllocation(operand_slice.index()); - if (colocated_allocations.count(allocation->index()) == 0) { + if (!colocated_allocations.contains(allocation->index())) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). 
However, @@ -1087,7 +1084,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Instructions are iterated in increasing buffer size, so any // previously create allocation must be large enough to hold this // instruction's output (with the exception of colocated buffers). - if (colocated_allocations.count(allocation->index()) == 0) { + if (!colocated_allocations.contains(allocation->index())) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). However, @@ -1313,10 +1310,10 @@ std::vector ComputePeakMemoryLogicalBuffers( live_buffers.end()); // Stabily sort the live buffers. - std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); + absl::c_sort(live_buffers_vector, + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); return live_buffers_vector; } @@ -1376,7 +1373,7 @@ void BufferAssigner::AddSetToColocatedBufferSets( std::vector overlap_set_indices; for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { for (const LogicalBuffer* buffer : colocated_set) { - if ((*colocated_buffer_sets)[index].count(buffer) > 0) { + if ((*colocated_buffer_sets)[index].contains(buffer)) { VLOG(5) << "Found overlap with existing set on buffer " << buffer->ToString() << "\n" << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], @@ -1425,12 +1422,14 @@ BufferAssigner::MergeColocatedBufferSets( << colocated_buffer_sets.size(); // Returns true if the given buffer is for the entry parameter. 
- auto is_entry_parameter = [](const LogicalBuffer& buffer) { + auto is_readonly_entry_parameter = [](const LogicalBuffer& buffer) { auto* instruction = buffer.instruction(); auto* computation = instruction->parent(); auto* module = computation->parent(); return instruction->opcode() == HloOpcode::kParameter && - computation == module->entry_computation(); + computation == module->entry_computation() && + !module->input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), buffer.index()); }; std::vector set_can_be_merged(colocated_buffer_sets.size(), true); @@ -1452,7 +1451,7 @@ BufferAssigner::MergeColocatedBufferSets( for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { for (auto& buffer : colocated_buffer_sets[i]) { if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer) || + is_readonly_entry_parameter(*buffer) || buffer->instruction()->opcode() == HloOpcode::kConstant) { set_can_be_merged[i] = false; break; @@ -1539,15 +1538,16 @@ void BufferAssigner::BuildColocatedBufferSets( VLOG(4) << "Input/Output Alias Config: "; VLOG(4) << module->input_output_alias_config(); module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { std::vector colocated_set; AddBufferToColocatedSet(module->entry_computation()->root_instruction(), output_index, points_to_analysis, &colocated_set); AddBufferToColocatedSet( - module->entry_computation()->parameter_instruction(param_number), - param_index, points_to_analysis, &colocated_set); + module->entry_computation()->parameter_instruction( + alias.parameter_number), + alias.parameter_index, points_to_analysis, &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); @@ -1741,10 +1741,6 @@ void BufferAssigner::AssignColocatedBufferSets( // module-level scope, we can allow buffers to be 
shared across // computations (in some cases). allocation = assignment->NewAllocation(*buffer, buffer_size); - if (entry_parameter_number >= 0) { - allocation->set_entry_computation_parameter( - entry_parameter_number, *entry_parameter_shape_idx); - } if (is_constant) { allocation->set_constant(true); } @@ -1758,6 +1754,16 @@ void BufferAssigner::AssignColocatedBufferSets( } colocated_buffers->insert(buffer); } + + // If an allocation contains a parameter, set corresponding fields. + if (entry_parameter_number >= 0) { + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + entry_parameter_number, *entry_parameter_shape_idx); + allocation->set_entry_computation_parameter(entry_parameter_number, + *entry_parameter_shape_idx, + parameter_has_alias); + } } } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 0a9fdede803e84ca42472259084615c031b206eb..448dec3b1aa0c0f85e1060a70e965fcf3952c320 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -96,7 +96,11 @@ class BufferAllocation { // Whether this allocation is readonly i.e. backed by memory we cannot write // to. bool is_readonly() const { - return is_entry_computation_parameter() || is_constant(); + // Entry parameters are generally readonly, except when they are aliased + // with any output. 
+ return (is_entry_computation_parameter() && + !is_parameter_aliased_with_output_) || + is_constant(); } bool is_tuple() const { return is_tuple_; } @@ -186,9 +190,10 @@ class BufferAllocation { end > other.offset_; } - struct Hasher { - size_t operator()(Slice s) const; - }; + template + friend H AbslHashValue(H h, const Slice& s) { + return H::combine(std::move(h), s.index(), s.offset(), s.size()); + } string ToString() const; @@ -273,8 +278,10 @@ class BufferAllocation { void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); void set_entry_computation_parameter(int64 parameter_number, - ShapeIndex param_shape_index) { + ShapeIndex param_shape_index, + bool parameter_aliased_with_output) { is_entry_computation_parameter_ = true; + is_parameter_aliased_with_output_ = parameter_aliased_with_output; parameter_number_ = parameter_number; param_shape_index_ = std::move(param_shape_index); } @@ -304,6 +311,9 @@ class BufferAllocation { // outlast the computation. bool is_entry_computation_parameter_ = false; + // Whether this entry computation parameter is aliased with output. + bool is_parameter_aliased_with_output_ = false; + // If this allocation holds an entry computation parameter, this field // indicates the index (starting from 0) of the parameter. int64 parameter_number_ = 0; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8f482e6ba8c3e71c9980be5e6947ea61f3b4ef29..580bc2f43384006eab8711490689a200fc887d37 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -309,7 +310,7 @@ class BufferAssignmentTest : public HloTestBase { static bool BuffersDistinct(const std::vector& a, const std::vector& b, const BufferAssignment& assignment) { - std::set a_slices; + absl::flat_hash_set a_slices; for (const HloInstruction* instruction : a) { if (assignment.HasTopLevelAllocation(instruction)) { a_slices.insert( @@ -319,8 +320,8 @@ static bool BuffersDistinct(const std::vector& a, for (const HloInstruction* instruction : b) { if (assignment.HasTopLevelAllocation(instruction)) { - if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction) - .ConsumeValueOrDie())) { + if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction) + .ConsumeValueOrDie())) { return false; } } @@ -464,6 +465,40 @@ TEST_F(BufferAssignmentTest, Basic) { GetAssignedOutputAllocation(*buffers, sub); } +TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) { + // If an input buffer and output buffer aliases, the input buffer can be + // reused for other intermediate results. 
+ // + // param0[100] ----- (neg1) -- (neg2) + // | | + // + -------- Aliased ---------+ + + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "p0")); + auto neg_1 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param)); + auto neg_2 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias( + {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); + + auto buffers = RunBufferAssignment(module.get()); + + BufferAllocation param_buffer = GetAssignedInputAllocation(*buffers, param); + BufferAllocation neg_1_buffer = GetAllocation(*buffers, neg_1, {}); + BufferAllocation neg_2_buffer = GetAllocation(*buffers, neg_2, {}); + + // Everything use one buffer. + EXPECT_EQ(param_buffer.index(), neg_1_buffer.index()); + EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index()); +} + TEST_F(BufferAssignmentTest, AddCannotReuse) { // Pass in a special rule to indicate that "add" cannot reuse any buffer. 
// @@ -2485,9 +2520,9 @@ while_body { get-tuple-element.3 = s32[] get-tuple-element(state), index=0 constant.2 = s32[] constant(128) add.5 = s32[] add(get-tuple-element.3, constant.2) - constant.3 = s32[3]{0} constant({0, 0, 0}) - dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3) - dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3) + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) } diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 40825a78716b1c0b9fb0121787977d275891c0f8..23b9af0281b0d5ee1ef6ca2315f0cc1042285609 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -52,8 +52,8 @@ class BufferLivenessTest : public HloTestBase { // interfere. Precondition: 'a' and 'b' are array-shaped. bool InstructionsMayInterfere(const BufferLiveness& liveness, HloInstruction* a, HloInstruction* b) { - EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); - EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + EXPECT_FALSE(a->shape().IsTuple()); + EXPECT_FALSE(b->shape().IsTuple()); return liveness.MayInterfere( GetBuffer(liveness, /*instruction=*/a, /*index=*/{}), GetBuffer(liveness, /*instruction=*/b, /*index=*/{})); @@ -66,8 +66,8 @@ class BufferLivenessTest : public HloTestBase { HloInstruction* a, HloInstruction* b, const ShapeIndex& index) { // Check that top-level shapes are tuple and tuple element shapes are equal. 
- EXPECT_TRUE(ShapeUtil::IsTuple(a->shape())); - EXPECT_TRUE(ShapeUtil::IsTuple(b->shape())); + EXPECT_TRUE(a->shape().IsTuple()); + EXPECT_TRUE(b->shape().IsTuple()); EXPECT_TRUE( ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index), ShapeUtil::GetSubshape(b->shape(), index))); @@ -638,10 +638,10 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); // Create output tuple. builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -794,10 +794,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); // Create output tuple. 
auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index fdf822c666b15afbc7553ca89d4f92ab08201869..b1abba20689915b03304aacd7a5fcca5443c2c60 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -29,8 +29,8 @@ BufferValue::BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id) : id_(id) { const Shape& shape = ShapeUtil::GetSubshape(instruction->shape(), index); - is_array_ = ShapeUtil::IsArray(shape); - is_tuple_ = ShapeUtil::IsTuple(shape); + is_array_ = shape.IsArray(); + is_tuple_ = shape.IsTuple(); } BufferValue::~BufferValue() {} diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 7987343bfaf1069fd550909d127e4b11f2124701..98304757cae91d22466ed25f8c6e36ce90a848db 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -58,12 +58,13 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: return CallContext::kParallel; default: @@ -236,6 +237,41 @@ void CallGraph::SetCallContexts() { } } +void CallGraph::SetNodeDepths() { + std::queue worklist; + + // Initialize node depths to -1. + for (CallGraphNode& node : nodes_) { + node.set_depth(-1); + } + + // Initialize worklist with all roots of the call graph (computations without + // callers). 
+ for (const HloComputation* computation : module_->computations()) { + CallGraphNode& node = GetNode(computation); + if (node.callers().empty()) { + node.set_depth(0); + worklist.push(&node); + } + } + + while (!worklist.empty()) { + CallGraphNode* node = worklist.front(); + worklist.pop(); + for (const HloComputation* callee : node->callees()) { + CallGraphNode& callee_node = GetNode(callee); + if (callee_node.depth() < node->depth() + 1) { + callee_node.set_depth(node->depth() + 1); + worklist.push(&callee_node); + } + } + } + + for (CallGraphNode& node : nodes_) { + CHECK_NE(node.depth(), -1); + } +} + /* static */ std::unique_ptr CallGraph::Build(const HloModule* module) { // Constructor for CallGraph is private so absl::make_unique can't be used. @@ -271,6 +307,8 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { } call_graph->SetCallContexts(); + call_graph->SetNodeDepths(); + XLA_VLOG_LINES(1, call_graph->ToString()); return call_graph; @@ -352,15 +390,38 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, // Iterate through the callee->caller chains and find the earliest common // element. - for (HloInstruction* a_ancestor = a; a_ancestor != nullptr; - a_ancestor = next_caller(a_ancestor)) { - for (HloInstruction* b_ancestor = b; b_ancestor != nullptr; - b_ancestor = next_caller(b_ancestor)) { - if (a_ancestor->parent() == b_ancestor->parent()) { - return {a_ancestor, b_ancestor}; + HloInstruction* a_ancestor = a; + HloInstruction* b_ancestor = b; + int a_depth = GetNode(a->parent()).depth(); + int b_depth = GetNode(b->parent()).depth(); + + // Advance a_ancestor (b_ancestor) up the call chain until the call depth of + // a_ancestor or b_ancestor are the same. Necessarily each call to next_caller + // reduces the depth by exactly one. 
+ if (a_depth > b_depth) { + for (int i = 0; i < a_depth - b_depth; ++i) { + a_ancestor = next_caller(a_ancestor); + if (a_ancestor == nullptr) { + return {nullptr, nullptr}; + } + } + } else if (b_depth > a_depth) { + for (int i = 0; i < b_depth - a_depth; ++i) { + b_ancestor = next_caller(b_ancestor); + if (b_ancestor == nullptr) { + return {nullptr, nullptr}; } } } + + while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) { + if (a_ancestor->parent() == b_ancestor->parent()) { + return {a_ancestor, b_ancestor}; + } + + a_ancestor = next_caller(a_ancestor); + b_ancestor = next_caller(b_ancestor); + } return {nullptr, nullptr}; } diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 05c7c998738f861ee804d1ec87bfa5fb17ddfb74..57a636fd740995d6cce933fe19d5592a64bde5cf 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -30,7 +30,7 @@ namespace xla { // The context in which a computation is called by another computation. enum class CallContext { - // In a parallel contex the computation is applied to each element of the + // In a parallel context the computation is applied to each element of the // array argument(s). kMap and kReduce instructions call computations in // parallel context. kParallel, @@ -121,6 +121,11 @@ class CallGraphNode { // Returns the context in which this computation is called. CallContext context() const { return context_; } + // Returns the depth of this node in the call graph. The depth is defined as + // the length of the longest call chain from a computation with no callers + // (usually the entry computation node) to this node. + int depth() const { return depth_; } + string ToString() const; private: @@ -130,6 +135,9 @@ class CallGraphNode { // Sets the context in which this computation is called. void set_context(CallContext value) { context_ = value; } + // Sets the depth of this node in the graph. 
+ void set_depth(int value) { depth_ = value; } + // Adds a callsite which calls this computation. Updates callers to include // the calling computation. void AddCallerCallSite(const CallSite& caller_callsite); @@ -164,6 +172,9 @@ class CallGraphNode { // The context in which this computation is called. CallContext context_ = CallContext::kNone; + + // The depth of this node in the call graph. + int depth_ = 0; }; // The call graph for an HLO module. The graph includes a node for each @@ -245,9 +256,16 @@ class CallGraph { private: CallGraph(const HloModule* module); + // Not copyable. + CallGraph(const CallGraph&) = delete; + CallGraph& operator=(const CallGraph&) = delete; + // Sets the call contexts for every node in the graph. void SetCallContexts(); + // Sets the call node depths for every node in the graph. + void SetNodeDepths(); + // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS // post order (callee before caller) calling visitor_func on each node. Adds // nodes to 'visited' as each node is visited. 
Skips nodes already in diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index a3ac2568b0f3eec8556a42dbe3c2c64bd8564468..5de724f8924b78008ba4c56603b61bf93fbc5e7c 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -102,6 +102,7 @@ TEST_F(CallGraphTest, SingletonComputation) { const CallGraphNode& node = call_graph->GetNode(computation); EXPECT_EQ(computation, node.computation()); + EXPECT_EQ(node.depth(), 0); EXPECT_TRUE(node.callsites().empty()); EXPECT_TRUE(node.callees().empty()); EXPECT_TRUE(node.caller_callsites().empty()); @@ -122,11 +123,13 @@ TEST_F(CallGraphTest, UnreachableComputation) { EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(entry_computation, entry_node.computation()); EXPECT_EQ(CallContext::kSequential, entry_node.context()); const CallGraphNode& unreachable_node = call_graph->GetNode(unreachable_computation); + EXPECT_EQ(unreachable_node.depth(), 0); EXPECT_EQ(unreachable_computation, unreachable_node.computation()); EXPECT_EQ(CallContext::kSequential, unreachable_node.context()); } @@ -145,6 +148,7 @@ TEST_F(CallGraphTest, ParallelComputation) { const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(CallContext::kSequential, entry_node.context()); EXPECT_EQ(5, entry_node.callsites().size()); EXPECT_EQ(1, entry_node.callees().size()); @@ -153,6 +157,7 @@ TEST_F(CallGraphTest, ParallelComputation) { const CallGraphNode& map_node = call_graph->GetNode(map_computation); EXPECT_EQ(map_computation, map_node.computation()); + EXPECT_EQ(map_node.depth(), 1); EXPECT_EQ(CallContext::kParallel, map_node.context()); EXPECT_TRUE(map_node.callsites().empty()); 
EXPECT_TRUE(map_node.callees().empty()); @@ -234,6 +239,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite); const CallGraphNode& sub_node = call_graph->GetNode(subcomputation); + EXPECT_EQ(sub_node.depth(), 1); EXPECT_EQ(CallContext::kBoth, sub_node.context()); } @@ -264,6 +270,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { EXPECT_EQ(3, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(entry_computation, entry_node.computation()); EXPECT_EQ(1, entry_node.callsites().size()); @@ -275,11 +282,13 @@ TEST_F(CallGraphTest, ComputationWithConditional) { EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); const CallGraphNode& true_node = call_graph->GetNode(true_computation); + EXPECT_EQ(true_node.depth(), 1); EXPECT_TRUE(true_node.callees().empty()); EXPECT_EQ(1, true_node.callers().size()); EXPECT_EQ(entry_computation, true_node.callers()[0]); const CallGraphNode& false_node = call_graph->GetNode(false_computation); + EXPECT_EQ(false_node.depth(), 1); EXPECT_TRUE(false_node.callees().empty()); EXPECT_EQ(1, false_node.callers().size()); EXPECT_EQ(entry_computation, false_node.callers()[0]); @@ -332,9 +341,21 @@ TEST_F(CallGraphTest, ComplexGraph) { EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + const CallGraphNode& a_node = call_graph->GetNode(a_computation); + const CallGraphNode& b_node = call_graph->GetNode(b_computation); + const CallGraphNode& c_node = call_graph->GetNode(c_computation); + const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); + + // Verify depths. 
+ EXPECT_EQ(entry_node.depth(), 0); + EXPECT_EQ(a_node.depth(), 1); + EXPECT_EQ(b_node.depth(), 2); + EXPECT_EQ(c_node.depth(), 3); + EXPECT_EQ(cond_node.depth(), 2); + // Entry computation has one while instruction calling two computations // (cond_computation and a_computation). - const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); ASSERT_EQ(1, entry_node.callsites().size()); const std::vector& called_computations = entry_node.callsites()[0].called_computations(); @@ -342,7 +363,6 @@ TEST_F(CallGraphTest, ComplexGraph) { UnorderedElementsAre(cond_computation, a_computation)); EXPECT_EQ(CallContext::kSequential, entry_node.context()); - const CallGraphNode& c_node = call_graph->GetNode(c_computation); EXPECT_TRUE(c_node.callsites().empty()); EXPECT_THAT(c_node.callers(), UnorderedElementsAre(a_computation, b_computation)); @@ -364,7 +384,7 @@ TEST_F(CallGraphTest, ComplexGraph) { // Verify visitation order of some computations in the graph. auto index_of = [&visited](const HloComputation* comp) { - auto it = std::find(visited.begin(), visited.end(), comp); + auto it = absl::c_find(visited, comp); EXPECT_NE(it, visited.end()); return std::distance(visited.begin(), it); }; diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 3c2d1ae6d82ebc6c10d52194fd1cec5e291025f7..b517495f2ea0c75679685c67f757ff586f8c79e3 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -72,7 +72,7 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { } Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { - if (opaque_to_channel_.count(handle.handle()) == 0) { + if (!opaque_to_channel_.contains(handle.handle())) { return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; @@ -94,7 +94,7 @@ Status 
ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { } Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { - if (opaque_to_channel_.count(handle.handle()) == 0) { + if (!opaque_to_channel_.contains(handle.handle())) { return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index 52037bf9b52556c6aa2e66dd3209e25cf085cfe3..89e17eba36f23077ce4cf0704e7455b76bee68d1 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status.h" @@ -83,7 +84,8 @@ class ChannelTracker { // Mapping from ChannelHandle value to the corresponding registered // Channel object. 
- std::map opaque_to_channel_ GUARDED_BY(channel_mutex_); + absl::flat_hash_map opaque_to_channel_ + GUARDED_BY(channel_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker); }; diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 8f08c244908efb823b3870c19bdc3491fa87d44f..653f4555a77cc82e91fb1cd26206b93826375732 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -98,10 +98,17 @@ Compiler::GetPlatformCompilers() { auto* factories = GetPlatformCompilerFactories(); auto it = factories->find(platform->id()); if (it == factories->end()) { + string hint; + if (platform->Name() == "Host") { + hint = " (hint: try linking in tensorflow/compiler/jit:xla_cpu_jit)"; + } else if (platform->Name() == "CUDA") { + hint = " (hint: try linking in tensorflow/compiler/jit:xla_gpu_jit)"; + } + return NotFound( "could not find registered compiler for platform %s -- check " - "target linkage", - platform->Name()); + "target linkage%s", + platform->Name(), hint); } // And then we invoke the factory, placing the result into the mapping. 
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index efc893818d03a20d6bd65b7dc1da72ea5da5ceb0..92d1ca4ba5da802a5f1c544017ac52dda38e9b1d 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -42,8 +42,8 @@ void ComputationLayout::SetToDefaultLayout() { } bool ComputationLayout::LayoutIsSet() const { - return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(), - [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && + return absl::c_all_of(parameter_layouts_, + [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && result_layout_.LayoutIsSet(); } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc similarity index 68% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter.cc rename to tensorflow/compiler/xla/service/convolution_group_converter.cc index 95c7724c3c93507ae61a984301ecfc0111bef192..f11f9e5fc2949a92f83ff66506a9b162ffda1c92 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include #include @@ -50,8 +50,12 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; + Status HandleBatchGroupCount(HloInstruction* convolution); + // Runs the visitor on a computation. 
static bool Run(HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, bool canonicalize_depthwise_filter); // Returns whether any convolution ops were rewritten. @@ -60,10 +64,15 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation, - bool canonicalize_depthwise_filter = false) + explicit ConvolutionVisitor( + HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, + bool canonicalize_depthwise_filter = false) : computation_(computation), - filter_expansion_(!canonicalize_depthwise_filter) {} + filter_expansion_(!canonicalize_depthwise_filter), + convert_batch_groups_only_(convert_batch_groups_only), + is_cost_viable_(is_cost_viable) {} // Current HloComputation instance the ConvolutionVisitor is traversing. HloComputation* computation_; @@ -73,11 +82,21 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { // Whether filter expansion is required. bool filter_expansion_; + + // Decides whether to convert batch groups or feature groups. 
+ bool convert_batch_groups_only_; + + // std::function(int64, int64)> chunk_fetcher + std::function is_cost_viable_; }; -bool ConvolutionVisitor::Run(HloComputation* computation, - bool canonicalize_depthwise_filter) { - ConvolutionVisitor visitor(computation, canonicalize_depthwise_filter); +bool ConvolutionVisitor::Run( + HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, bool canonicalize_depthwise_filter) { + ConvolutionVisitor visitor(computation, is_cost_viable, + convert_batch_groups_only, + canonicalize_depthwise_filter); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -176,18 +195,143 @@ HloInstruction* GetExpandedFilterMask( predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); } -Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { - int64 group_count = convolution->feature_group_count(); - if (group_count == 1) { +// This function handles batch_group_counts which are relevant only for +// depthwise backprop filter convolutions. 
+Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { + auto dim_numbers = convolution->convolution_dimension_numbers(); + auto activation = convolution->mutable_operand(0); + auto filter = convolution->mutable_operand(1); + int64 batch_group_count = convolution->batch_group_count(); + + if (batch_group_count == 1) { return Status::OK(); } - auto filter = convolution->mutable_operand(1); - changed_ = true; + + VLOG(2) << "Dealing with batch_group_count " << batch_group_count + << " for convolution " << convolution->ToString() << "\n"; + + auto add = [&](std::unique_ptr inst) { + return computation_->AddInstruction(std::move(inst)); + }; + + int64 input_batch_dimension = dim_numbers.input_batch_dimension(); + int64 output_batch_dimension = dim_numbers.output_batch_dimension(); + int64 output_feature_dimension = dim_numbers.output_feature_dimension(); + + int64 input_batch = activation->shape().dimensions(input_batch_dimension); + + // We are not yet supporting batch_group of sizes greater than 1. + TF_RET_CHECK(input_batch == batch_group_count); + + if (!is_cost_viable_(convolution) || filter_expansion_) { + // We first obtain the expanded the filter (which is the convolution + // output). The batch dimension is the expanded one (which originally + // represents kernel input feature dimension). We mask the filter to zero + // out the expanded regions. Next we reduce the filter in the batch + // dimension to obtain the original filter size. 
+ + HloInstruction* filter_mask = + GetExpandedFilterMask(convolution->shape(), output_batch_dimension, + output_feature_dimension, batch_group_count, add); + auto expanded_filter_shape = ExpandedFilterShape( + convolution->shape(), batch_group_count, output_batch_dimension); + + auto new_convolution = add(HloInstruction::CreateConvolve( + expanded_filter_shape, activation, filter, + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config())); + + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + + auto new_filter = add(HloInstruction::CreateTernary( + expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution, + zero_filter)); + + PrimitiveType reduce_type = new_filter->shape().element_type(); + auto reduce_window_shape = new_convolution->shape(); + reduce_window_shape.set_dimensions(output_batch_dimension, 1); + + // Ensure that data input to reduce window uses at least 32 bits. 
+ if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) { + reduce_type = F32; + reduce_window_shape.set_element_type(F32); + Shape convert_shape = new_filter->shape(); + convert_shape.set_element_type(F32); + new_filter = + add(HloInstruction::CreateConvert(convert_shape, new_filter)); + } + + auto zero_literal = LiteralUtil::Zero(reduce_type); + auto zero_scalar = + add(HloInstruction::CreateConstant(std::move(zero_literal))); + + auto reduce_function = [&]() -> HloComputation* { + HloComputation::Builder b("add_computation"); + Shape shape = ShapeUtil::MakeShape(reduce_type, {}); + auto lhs = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); + auto rhs = + b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); + auto scalar_op = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs)); + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + }; + + // Create the reduce window. + Window window; + for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) { + auto* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + if (i == output_batch_dimension) { + dim->set_stride(batch_group_count); + dim->set_size(batch_group_count); + } else { + dim->set_stride(1); + dim->set_size(1); + } + } + auto reduce_window = add(HloInstruction::CreateReduceWindow( + reduce_window_shape, new_filter, zero_scalar, window, + reduce_function())); + + Shape convert_back_shape = reduce_window->shape(); + convert_back_shape.set_element_type(activation->shape().element_type()); + + // Convert reduced data back to the original data type. 
+ auto reduce_window_converted = + HloInstruction::CreateConvert(convert_back_shape, reduce_window); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reduce_window_converted))); + changed_ = true; + } + + return Status::OK(); +} + +Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + if (convert_batch_groups_only_) { + return HandleBatchGroupCount(convolution); + } + auto add = [&](std::unique_ptr inst) { return computation_->AddInstruction(std::move(inst)); }; + int64 group_count = convolution->feature_group_count(); + if (group_count == 1) { + return Status::OK(); + } + + changed_ = true; auto dim_numbers = convolution->convolution_dimension_numbers(); + auto filter = convolution->mutable_operand(1); int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension(); int64 group_size = filter->shape().dimensions(kernel_input_feature_dim); int64 kernel_output_feature_dim = @@ -205,6 +349,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // If the code generator handles depthwise separable convolutions // inherently, then no filter expansion is needed. 
if (!filter_expansion_ && depthwise_separable) { + changed_ = false; return Status::OK(); } // We want to repeat 'filter' in the 'input_feature_dim' dimension @@ -233,8 +378,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { auto new_convolution = HloInstruction::CreateConvolve( convolution->shape(), convolution->mutable_operand(0), new_filter, - /*feature_group_count=*/1, convolution->window(), dim_numbers, - convolution->precision_config()); + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); } else { @@ -294,8 +439,9 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { dim->set_size(group_size); auto new_convolution = add(HloInstruction::CreateConvolve( - new_output_shape, activation, filter, group_count, new_window, - dim_numbers, convolution->precision_config())); + new_output_shape, activation, filter, group_count, + /*batch_group_count=*/1, new_window, dim_numbers, + convolution->precision_config())); // Delete the extra spatial dimension, and reshape. 
Shape reshaped_convolution_shape = @@ -372,7 +518,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { auto new_convolution = add(HloInstruction::CreateConvolve( conv_slice_shape, activation_slice, filter_slice, - /*feature_group_count=*/1, convolution->window(), dim_numbers, + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config())); sliced_convolutions.push_back(new_convolution); @@ -390,17 +537,19 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { } // namespace -StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { - XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), before:\n" + - module->ToString()); +StatusOr ConvolutionGroupConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp, filter_expansion_)) { + if (ConvolutionVisitor::Run(comp, is_cost_viable_, + convert_batch_groups_only_, + filter_expansion_)) { changed = true; } } - XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), after:\n" + - module->ToString()); + XLA_VLOG_LINES( + 2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_group_converter.h similarity index 58% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter.h rename to tensorflow/compiler/xla/service/convolution_group_converter.h index cb6bc04c00a2ff10f970da2a07fb540a561dad5a..1caf1841119a965044502435fe0f5b38ca94f6a5 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_group_converter.h @@ -13,8 +13,8 @@ See the License for the specific 
language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -25,23 +25,34 @@ namespace xla { // A pass which rewrites convolutions with feature_group_count > 1 into // convolutions with feature_group_count = 1. -class ConvolutionFeatureGroupConverter : public HloModulePass { +class ConvolutionGroupConverter : public HloModulePass { public: - ConvolutionFeatureGroupConverter(bool canonicalize_depthwise_filter = false) - : filter_expansion_(canonicalize_depthwise_filter) {} + ConvolutionGroupConverter(std::function is_cost_viable, + bool convert_batch_groups_only, + bool canonicalize_depthwise_filter = false) + : is_cost_viable_(is_cost_viable), + convert_batch_groups_only_(convert_batch_groups_only), + filter_expansion_(canonicalize_depthwise_filter) {} absl::string_view name() const override { - return "convolution-feature-group-converter"; + return "convolution-group-converter"; } // Run convolution rewriting on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + // Lambda containing cost model that decides whether to expand + // batch_group_count. + std::function is_cost_viable_; + + // Decides whether to convert batch groups or feature groups. + bool convert_batch_groups_only_; + // Tells whether filter expansion is required. 
bool filter_expansion_; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc similarity index 68% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc rename to tensorflow/compiler/xla/service/convolution_group_converter_test.cc index e6bf2143a21bd5001d3530fe8727c88504be1d43..9cee3eda95252d6c7d725fbb03030bd58f52e71f 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include #include @@ -30,10 +30,10 @@ limitations under the License. 
namespace xla { namespace { -using ConvolutionFeatureGroupConverterTest = HloTestBase; +using ConvolutionGroupConverterTest = HloTestBase; namespace op = testing::opcode_matchers; -TEST_F(ConvolutionFeatureGroupConverterTest, +TEST_F(ConvolutionGroupConverterTest, ConvertFeatureGroupCountEqualToInputFeatureDim) { string hlo_string = R"(HloModule Convolve1D1Window_0_module @@ -49,7 +49,8 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - ConvolutionFeatureGroupConverter converter; + ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/ + false); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure the convolution is converted to one with feature_group_count = 1. @@ -63,7 +64,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2 op::Broadcast(op::Constant()))); } -TEST_F(ConvolutionFeatureGroupConverterTest, +TEST_F(ConvolutionGroupConverterTest, ConvertFeatureGroupCountDivisorOfInputFeatureDim) { string hlo_string = R"(HloModule Convolve1D1Window_0_module @@ -79,7 +80,8 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - ConvolutionFeatureGroupConverter converter; + ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/ + false); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure the convolution is replaced with a concatenate. 
@@ -92,5 +94,32 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 EXPECT_EQ(root->operand(1)->feature_group_count(), 1); } +TEST_F(ConvolutionGroupConverterTest, + ConvertBatchGroupCountEqualToInputBatchDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16,19,19,512]{3,2,1,0}) -> f32[3,3,512,1]{3,2,1,0} { + %input = f32[16,19,19,512]{3,2,1,0} parameter(0) + %filter = f32[16,19,19,512]{3,2,1,0} parameter(1) + ROOT %convolution = f32[3,3,512,1]{3,2,1,0} convolution(f32[16,19,19,512]{3,2,1,0} %input, f32[16,19,19,512]{3,2,1,0} %filter), window={size=19x19 pad=1_1x1_1}, dim_labels=f01b_i01o->01fb, batch_group_count=512 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + auto cost_model = [](HloInstruction* conv) { return false; }; + ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/ + true); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + + // Verify that the convolution is replaced by a convert. + EXPECT_EQ(root->opcode(), HloOpcode::kConvert); + // Make sure the convert is being fed by a reduce window. 
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kReduceWindow); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index df6059663876dfde71f4c75d3931b3d2de72c1df..5e26a63cebfa9b2e50f4b13335c10c246999d4df 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -349,11 +349,12 @@ Status AddCopiesForAliasedInputOutputs(HloModule* module) { ShapeTree param_indices_to_copy(param->shape()); module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { - if (param_number == param->parameter_number()) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + if (alias.parameter_number == param->parameter_number()) { param_has_alias = true; - *(param_indices_to_copy.mutable_element(param_index)) = true; + *(param_indices_to_copy.mutable_element(alias.parameter_index)) = + true; *(output_indices_to_copy.mutable_element(output_index)) = true; } }); @@ -395,13 +396,14 @@ Status AddCopiesForAliasedInputOutputs(HloModule* module) { // Add control dependencies between the input/output copies. 
TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& input_index) -> Status { - if (!copied_parameters[param_number]) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) -> Status { + if (!copied_parameters[alias.parameter_number]) { return Status::OK(); } HloInstruction* from = - copied_parameters[param_number]->element(input_index); + copied_parameters[alias.parameter_number]->element( + alias.parameter_index); HloInstruction* to = output_copy_tree.element(output_index); TF_RET_CHECK(from != nullptr); @@ -522,7 +524,7 @@ class CopyRemover { // between copies added around aliased operations (kWhile) guarantees // this strict order. for (const HloValue* value_a : buffer.values()) { - if (ShapeUtil::IsToken(value_a->shape())) { + if (value_a->shape().IsToken()) { // Token values have no representation and cannot interfere. continue; } @@ -539,10 +541,9 @@ class CopyRemover { } std::vector values = buffer.values(); - std::sort(values.begin(), values.end(), - [this](const HloValue* a, const HloValue* b) { - return ordering_.IsDefinedBefore(*a, *b); - }); + absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); // Create a list containing all of the values in the buffer. AddValueList(values, &value_to_node); @@ -842,12 +843,11 @@ class CopyRemover { copy_value_node->next->prev = operand_node; // Patch up uses. Remove use of copy from operand_node uses. 
- auto it = - std::find_if(operand_node->uses.begin(), operand_node->uses.end(), - [copy_value_node](const HloUse* use) { - return use->instruction == - copy_value_node->value->defining_instruction(); - }); + auto it = absl::c_find_if( + operand_node->uses, [copy_value_node](const HloUse* use) { + return use->instruction == + copy_value_node->value->defining_instruction(); + }); CHECK(it != operand_node->uses.end()); operand_node->uses.erase(it); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index e4e9d7ba05c115be9dd0eb53ebd7de208d514efb..4391bdcba532661a0fde789e2c4ed324c40bcd32 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1376,9 +1376,11 @@ TEST_F(CopyInsertionTest, CrossingParameters) { builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 4); @@ -1409,9 +1411,11 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1475,7 +1479,8 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -1516,7 +1521,8 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1557,7 +1563,8 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1848,8 +1855,7 @@ ENTRY %TokensShouldNotBeCopied () -> s32[] { } )"; 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - HloRunner::CreateModuleFromString( - module_string, GetDebugOptionsForTest())); + ParseAndReturnVerifiedModule(module_string)); InsertCopies(module.get()); // There should be no copies added because tokens should not be copied. @@ -2112,8 +2118,7 @@ ENTRY TestComputation { ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 } )"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); InsertCopies(module.get()); } @@ -2213,8 +2218,7 @@ ENTRY TestComputation { ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 } )"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); InsertCopies(module.get()); } @@ -2231,7 +2235,7 @@ cond.inner { body.inner { param.body.inner = pred[] parameter(0) - ROOT neg = pred[] negate(param.body.inner) + ROOT not = pred[] not(param.body.inner) } cond.outer { @@ -2248,9 +2252,8 @@ ENTRY TestComputation { ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); InsertCopies(module.get()); // There should only be a single copy inserted, and it's in the entry diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ce4c2a9cc69240b9565b35a3f2504d7fc9373917..42672bc3875af2d732d80691df6bf85b9d8080cd 100644 --- 
a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -1,6 +1,14 @@ # Description: # LLVM-based CPU backend for XLA. +load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") +load( + "//third_party/mkl:build_defs.bzl", + "mkl_deps", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load(":build_defs.bzl", "runtime_copts") + licenses(["notice"]) # Apache 2.0 package( @@ -14,15 +22,6 @@ package_group( ], ) -load(":build_defs.bzl", "runtime_copts") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") -load( - "//third_party/mkl:build_defs.bzl", - "mkl_deps", -) - # Filegroup used to collect source files for dependency checking. filegroup( name = "c_srcs", @@ -95,6 +94,7 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:scatter_expander", @@ -112,8 +112,9 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:convolution_feature_group_converter", + "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -133,7 +134,9 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", 
"//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:transpose_folding", + "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", @@ -241,6 +244,7 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/host:host_stream", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -364,15 +368,33 @@ cc_library( ], ) +cc_library( + name = "tiled_dot_emitter", + srcs = ["tiled_dot_emitter.cc"], + hdrs = ["tiled_dot_emitter.h"], + deps = [ + ":vector_support_library", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], - hdrs = ["dot_op_emitter.h"], + hdrs = [ + "dot_op_emitter.h", + ], deps = [ ":cpu_options", ":cpu_runtime", ":ir_emission_utils", ":target_machine_features", + ":tiled_dot_emitter", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -380,6 +402,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", @@ -572,6 +595,7 @@ cc_library( 
":runtime_matvec", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) @@ -630,6 +654,7 @@ cc_library( deps = [ ":runtime_matvec", "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) @@ -1005,7 +1030,6 @@ tf_cc_test( size = "small", srcs = ["cpu_eigen_tensor_alignment_test.cc"], deps = [ - ":dot_op_emitter", ":ir_emission_utils", ":target_machine_features_fake", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/xla/service/cpu/build_defs.bzl b/tensorflow/compiler/xla/service/cpu/build_defs.bzl index e78330b21689fdd818cd97128bbcaaa9e0118602..ffa1cd4ec8e26e7dbe92e7b99cf65e99db5400b9 100644 --- a/tensorflow/compiler/xla/service/cpu/build_defs.bzl +++ b/tensorflow/compiler/xla/service/cpu/build_defs.bzl @@ -1,12 +1,11 @@ """build_defs for service/cpu.""" - def runtime_copts(): - """Returns copts used for CPU runtime libraries.""" - return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ - "//tensorflow:android_arm": ["-mfpu=neon"], - "//conditions:default": [] - }) + select({ - "//tensorflow:android": ["-O2"], - "//conditions:default": [] - })) + """Returns copts used for CPU runtime libraries.""" + return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ + "//tensorflow:android_arm": ["-mfpu=neon"], + "//conditions:default": [], + }) + select({ + "//tensorflow:android": ["-O2"], + "//conditions:default": [], + })) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 796a7cf94d02b0ad42366387a9d3f8d589b8840a..414eacddfc7ba3c295c027c64c445a2046235d36 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -66,9 +66,14 @@ class FilteredPassManager : public llvm::legacy::PassManager { explicit FilteredPassManager(bool 
disable_expensive_passes) : disable_expensive_passes_(disable_expensive_passes) {} void add(llvm::Pass* p) override { + llvm::StringRef PassName = p->getPassName(); + if (PassName.contains("Warn about non-applied transformations")) { + delete p; + return; + } if (disable_expensive_passes_) { - llvm::StringRef PassName = p->getPassName(); if (PassName.contains("Unroll loops")) { + delete p; return; } } diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 2d9978404cc9ec1e40fc61aaf794a8f1f06050bb..8e55267a67d330e7e721f9b5fb25451357a49a9d 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -132,7 +132,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { HloInstruction* new_conv = module->entry_computation()->AddInstruction( HloInstruction::CreateConvolve( new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), - hlo->window(), new_dnums, hlo->precision_config())); + hlo->batch_group_count(), hlo->window(), new_dnums, + hlo->precision_config())); // Reshape the output back to the shape of the original convolution. 
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index c58175428fea6a2d38253c35de598b99a4281bf1..02085108a081358cd4f8aed6dc12557cbd8eea85 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -84,8 +84,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, /*feature_group_count=*/1, conv_window_, dnums, - DefaultPrecisionConfig(2))); + input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window_, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -147,8 +147,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, /*feature_group_count=*/1, conv_window_, dnums, - DefaultPrecisionConfig(2))); + input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window_, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 6374822c81bf42fd12829f57cf93c19457128219..19ab3bddb567afeeddb7c01b9a847b51bea5d957 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -51,7 +51,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" @@ -69,6 +69,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -92,7 +93,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" +#include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" @@ -103,6 +106,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/dynamic_annotations.h" namespace xla { namespace cpu { @@ -244,6 +248,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( @@ -252,12 +257,23 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); + pipeline.AddPass(); + // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(/*decompose_batch_dot=*/false); + auto cost_model = [](HloInstruction* conv) { + // We need a cost model for CPUs. Currently, do nothing. + return false; + }; + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/true); + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/false); pipeline.AddPass(target_machine_features); { auto& pass = @@ -270,10 +286,10 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); pipeline.AddPass(); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return false; }); + AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(false); pass.AddPass(options); + pass.AddPass(); pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO @@ -293,7 +309,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) + return DotImplementationCanHandleTranspose(dot, + *target_machine_features) ? 
candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -336,8 +353,7 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pass.AddInvariantChecker( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return true; }); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); options.set_enable_dot_strength_reduction(false); pass.AddPass>(options); @@ -497,7 +513,7 @@ Status CreateHloProfilingArtifacts( auto shape_size_bytes = [](const Shape& shape) { // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return static_cast(sizeof(void*)); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); @@ -621,7 +637,13 @@ StatusOr> CpuCompiler::RunBackend( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features); + &target_machine_features, +#ifdef MEMORY_SANITIZER + /*emit_code_for_msan=*/true +#else + /*emit_code_for_msan=*/false +#endif + ); TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); @@ -635,18 +657,17 @@ StatusOr> CpuCompiler::RunBackend( .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() ? 
"__compute" : entry_computation->name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation( - entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &schedule.sequence(entry_computation).instructions())); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -659,9 +680,9 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } - TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); @@ -820,7 +841,9 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features); + &target_machine_features, + // TODO(b/66051036): Run full msan for AOT. 
+ /*emit_code_for_msan=*/false); TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); @@ -835,7 +858,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); @@ -843,7 +866,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, ir_emitter.EmitComputation( computation, entry_point_name, /*is_top_level_computation=*/true, - &schedule.sequence(computation).instructions())); + schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index 8727c72b6e42517b1859e98ecadb41bbceed761c..485769a373acf5ae70c471b1a5dfcfb20ff772ef 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -28,37 +27,6 @@ namespace { class CpuEigenTensorAlignmentTest : public ::testing::Test {}; -TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { - string hlo_string = R"( -HloModule DotOperation - -ENTRY DotOperation { - arg0 = f32[5,256] parameter(0) - arg1 = f32[256,1024] parameter(1) - ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_string)); - - HloInstruction* dot = module->entry_computation()->root_instruction(); - - TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( - [](int64 size) { return 1; }); - - EXPECT_FALSE( - PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); - - TargetMachineFeaturesWithFakeAlignmentLogic - target_machine_with_full_alignment([](int64 size) { - return TargetMachineFeatures::kEigenExpectedTensorAlignment; - }); - - EXPECT_TRUE(PotentiallyImplementedAsEigenDot( - *dot, target_machine_with_full_alignment)); -} - TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { string hlo_string = R"( HloModule ConvOperation diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 818b2b0d0db2893e11fa46c7867e6c74bbbb6905..23d0af34233858515af21df5e92346742a5b5dc3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -213,6 +213,8 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), run_options->allocator(), 
stream->parent()->device_ordinal()); + const HloInputOutputAliasConfig& input_output_alias = + module().input_output_alias_config(); // Move OwningDeviceMemory values which contain the array(s) of the result // into the respective location in ScopedShapedBuffer which is returned to the @@ -232,12 +234,31 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, this->assignment_->GetUniqueSlice(src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - const BufferAllocation::Index buffer_index = slice.index(); OwningDeviceMemory& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer.Forget(); + if (!slice.allocation()->is_entry_computation_parameter()) { + // If the buffer coming out of the result is from a parameter, the + // owning buffer will be null, and that means the caller aliased some + // parameter buffer to an output one (via the + // HloInputOutputAliasConfig API). If that is the case, the caller + // will receive a partially complete scoped shaped buffer, which they + // will have to fill up on return. Unfortunately the interface to the + // execute APIs are ShapedBuffer pointer based, which assumes caller + // ownership, and hence a buffer coming from there cannot be part of + // the new ScopedShapedBuffer we create for the result (which assumes + // ownership). 
+ *device_memory = buffer.Forget(); + } else { + auto output_alias = input_output_alias.GetAliasedOutput( + slice.allocation()->parameter_number(), + slice.allocation()->param_shape_index()); + CHECK(output_alias) + << "Ouput buffer is coming from parameter " + << slice.allocation()->parameter_number() << " at index " + << slice.allocation()->param_shape_index() + << ", but no alias exists"; + CHECK_EQ(*output_alias, index); + } return Status::OK(); })); return std::move(result_buffer); @@ -326,7 +347,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return sizeof(void*); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc index 7fbe0fa157c57eb0c274662a1de95cf5328ccfa8..4ac61f44d9f38425da2d1fc6b9495cb4deba5047 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 527df0bd1c23bba74f32226e5622fed32f7dcf84..c4bde837e57e82584c2a007858ed8d55608acd3c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -332,7 +332,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {8}); - Shape starts_shape = ShapeUtil::MakeShape(F32, {2}); + Shape starts_shape = ShapeUtil::MakeShape(F32, {}); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8}); Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -340,13 +340,15 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloInstruction::CreateParameter(0, param_shape, "param")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, starts_shape, "starts")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); HloInstruction* broadcast2 = builder.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); HloInstruction* reshape3 = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, broadcast2)); HloInstruction* dynamic_slice4 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, reshape3, param1, {4, 4})); + dynamic_slice_shape, reshape3, {param1, param2}, {4, 4})); builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, 
HloOpcode::kTanh, dynamic_slice4)); @@ -356,7 +358,8 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { RunFusionAndCheckOpcodesWereFused( module.get(), {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape, - HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter}); + HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter}); } TEST_F(OpcodeFusionTest, Broadcast_Negate) { @@ -381,14 +384,14 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {4}); - Shape slice_shape = ShapeUtil::MakeShape(F32, {1}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {}); Shape result_shape = ShapeUtil::MakeShape(F32, {2}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, slice_shape, "starts")); HloInstruction* dynamic_slice2 = builder.AddInstruction( - HloInstruction::CreateDynamicSlice(result_shape, param0, param1, {2})); + HloInstruction::CreateDynamicSlice(result_shape, param0, {param1}, {2})); builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, dynamic_slice2)); @@ -548,28 +551,36 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + std::vector slice_indices, update_indices; + for (int i = 0; i < 3; ++i) { + slice_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + 1 + i, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + update_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + 5 + i, ShapeUtil::MakeShape(U32, {}), "update_indices"))); + } HloInstruction* slice = 
builder.AddInstruction(HloInstruction::CreateDynamicSlice( slice_shape, builder.AddInstruction( HloInstruction::CreateParameter(0, full_shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + slice_indices, /*slice_sizes=*/{10, 1, 1000})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_shape, builder.AddInstruction( - HloInstruction::CreateParameter(2, full_shape, "to_update")), - slice, - builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); + HloInstruction::CreateParameter(4, full_shape, "to_update")), + slice, update_indices)); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( - module.get(), {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, - HloOpcode::kParameter, HloOpcode::kParameter, - HloOpcode::kParameter, HloOpcode::kParameter}); + module.get(), + {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}); } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { @@ -578,49 +589,40 @@ TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); - auto loop_idx = builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(S32, {1}), - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(S32, {}), "param0")))); - + auto loop_idx = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "param0")); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(S32, {1}), "param1")); - auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(S32, {5}), - {loop_idx, 
param1, param1, param1, param1}, /*dimension=*/0)); + 1, ShapeUtil::MakeShape(S32, {}), "param1")); - auto idx_choice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {1}), - builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(S32, {4}), "param2")), - loop_idx, - /*slice_sizes=*/{1})); - - PaddingConfig padding_config; - padding_config.add_dimensions()->set_edge_padding_high(4); - auto pad = builder.AddInstruction(HloInstruction::CreatePad( - ShapeUtil::MakeShape(S32, {5}), idx_choice, - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), - padding_config)); + auto idx_choice = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {}), + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(S32, {4}), "param2")), + {loop_idx}, + /*slice_sizes=*/{1})))); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}), builder.AddInstruction(HloInstruction::CreateParameter( 3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3")), - pad, /*slice_sizes=*/{1, 100, 10, 100, 50})); + {idx_choice, zero, zero, zero, zero}, + /*slice_sizes=*/{1, 100, 10, 100, 50})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_shape, builder.AddInstruction( HloInstruction::CreateParameter(4, full_shape, "param4")), - slice, concat)); + slice, {loop_idx, param1, param1, param1, param1})); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( module.get(), - {HloOpcode::kConcatenate, HloOpcode::kPad, HloOpcode::kDynamicSlice, - HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, + {HloOpcode::kDynamicSlice, 
HloOpcode::kDynamicSlice, + HloOpcode::kDynamicUpdateSlice, HloOpcode::kReshape, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); } @@ -930,9 +932,10 @@ ENTRY main { return result; } -INSTANTIATE_TEST_CASE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest, - ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), - GatherLoopFusionTestSpec::Name); +INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation, + GatherLoopFusionTest, + ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), + GatherLoopFusionTestSpec::Name); } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index c291bf2d1ba2eaff4192051840768c037bece86f..95b8025f873c56bea063ff258d4abd6614257d85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -46,8 +46,7 @@ static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { for (auto* user : instruction->users()) { optional operand_idx = ProfitableToMakeDotOperandColumnMajor(*user); if (!operand_idx || user->operand(*operand_idx) != instruction || - std::count(user->operands().begin(), user->operands().end(), - instruction) != 1) { + absl::c_count(user->operands(), instruction) != 1) { return false; } } @@ -94,60 +93,38 @@ static Shape ColMajorShape(const Shape& old_shape) { return new_shape; } +static bool OperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& instr, + const TargetMachineFeatures& target_machine_features) { + if (instr.opcode() == HloOpcode::kConvolution) { + return PotentiallyImplementedAsEigenConvolution(instr, + target_machine_features); + } else if (instr.opcode() == HloOpcode::kDot) { + return DotOperandsAndResultMustHaveRowMajorLayout(instr, + target_machine_features); + } + return false; 
+} + Status CpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { ShouldMakeOperandColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction, - target_machine_features_)) { - const HloInstruction* convolution = instruction; - const HloInstruction* lhs_instruction = convolution->operand(0); - const HloInstruction* rhs_instruction = convolution->operand(1); - - // In order to implement `convolution` with Eigen convolution, the layouts - // of the input, filter, and output need to be row-major. - // - // These constraints are not hard constraints. Ideally, we should decide - // which layouts to choose according to some cost model. - Shape output_shape(RowMajorShape(convolution->shape())); - Shape input_shape(RowMajorShape(lhs_instruction->shape())); - Shape filter_shape(RowMajorShape(rhs_instruction->shape())); - - // Set layouts of the instructions' shapes. 
- TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, convolution, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, convolution, 1)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(output_shape, convolution)); + if (OperandsAndResultMustHaveRowMajorLayout(*instruction, + target_machine_features_)) { + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + RowMajorShape(instruction->shape()), instruction)); + for (int i = 0; i < instruction->operand_count(); i++) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + RowMajorShape(instruction->operand(i)->shape()), instruction, i)); + } } else if (optional op_idx = ShouldMakeOperandColumnMajor(&cache, *instruction)) { const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction, - target_machine_features_)) { - const HloInstruction* dot = instruction; - // In order to implement `dot` with Eigen dot, the layouts of the lhs, - // rhs, and output need to be row-major. - // - // These constraints are not hard constraints. Ideally, we should decide - // which layouts to choose according to some cost model. - Shape output_shape(RowMajorShape(dot->shape())); - - const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - - // Set layouts of the instructions' shapes. 
- TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); } else { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { @@ -160,7 +137,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( continue; } // Skip operands with non-array shapes. - if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + if (!instruction->operand(operand_no)->shape().IsArray()) { continue; } Shape operand_shape( @@ -175,7 +152,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } // Skip instructions which don't produce array shapes (tuples, opaque, // etc.). - if (!ShapeUtil::IsArray(instruction->shape())) { + if (!instruction->shape().IsArray()) { continue; } } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 92debb83e33b1400a59e5eef0f90971392ab7b22..ff654c83d61e7cc09ac7839feccaf2bc9cb3c63c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -23,8 +23,8 @@ namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; -const char* const kXlaEnableExperimentalLlvmIrGemm = - "xla_enable_experimental_llvm_ir_gemm"; +const char* const kXlaForceEnableExperimentalLlvmIrGemm = + "xla_force_enable_experimental_llvm_ir_gemm"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -57,10 +57,10 @@ absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config) { return absl::nullopt; } -bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { +bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); - return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; + return 
extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } static absl::string_view RemoveSuffix(absl::string_view str, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 47c7eb13b6e4cc05a23f82b8d2a25249f4b82ac0..99e6702d14aed8ffb148adec2bdd02dbc7c3c7e3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -26,7 +26,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); -bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); +bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index a9febe891b5e9d1eb9e6b297952b50d1d26a3396..d8878e622c0500fc5328aa6c295a9e24a3a037f7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -84,31 +84,8 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; -extern const char* const kKeyValueSortPREDSymbolName = - "__xla_cpu_runtime_KeyValueSortPRED"; -extern const char* const kKeyValueSortS8SymbolName = - "__xla_cpu_runtime_KeyValueSortS8"; -extern const char* const kKeyValueSortU8SymbolName = - "__xla_cpu_runtime_KeyValueSortU8"; -extern const char* const kKeyValueSortS16SymbolName = - "__xla_cpu_runtime_KeyValueSortS16"; -extern const char* const kKeyValueSortU16SymbolName = - "__xla_cpu_runtime_KeyValueSortU16"; -extern const char* const kKeyValueSortF16SymbolName = - 
"__xla_cpu_runtime_KeyValueSortF16"; -extern const char* const kKeyValueSortS32SymbolName = - "__xla_cpu_runtime_KeyValueSortS32"; -extern const char* const kKeyValueSortU32SymbolName = - "__xla_cpu_runtime_KeyValueSortU32"; -extern const char* const kKeyValueSortF32SymbolName = - "__xla_cpu_runtime_KeyValueSortF32"; -extern const char* const kKeyValueSortS64SymbolName = - "__xla_cpu_runtime_KeyValueSortS64"; -extern const char* const kKeyValueSortU64SymbolName = - "__xla_cpu_runtime_KeyValueSortU64"; -extern const char* const kKeyValueSortF64SymbolName = - "__xla_cpu_runtime_KeyValueSortF64"; - +extern const char* const kKeyValueSortSymbolName = + "__xla_cpu_runtime_KeyValueSort"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index b2e760a224ad8eaa61dae57b0f9cece04a7e54ae..3a2b44d8c1a80128d3577c374e751e73a89e9d59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -64,18 +64,7 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; -extern const char* const kKeyValueSortPREDSymbolName; -extern const char* const kKeyValueSortS8SymbolName; -extern const char* const kKeyValueSortU8SymbolName; -extern const char* const kKeyValueSortS16SymbolName; -extern const char* const kKeyValueSortU16SymbolName; -extern const char* const kKeyValueSortF16SymbolName; -extern const char* const kKeyValueSortS32SymbolName; -extern const char* const kKeyValueSortU32SymbolName; -extern const char* const kKeyValueSortF32SymbolName; -extern const char* const kKeyValueSortS64SymbolName; -extern const char* const 
kKeyValueSortU64SymbolName; -extern const char* const kKeyValueSortF64SymbolName; +extern const char* const kKeyValueSortSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 1ae3aa57111e3a3b7ac18b4907c5c282edf89b7e..4e8c98678309fa4d573f1aac1290c9afc87643a4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -162,11 +162,12 @@ TEST_P(EigenMatMulTest, DoIt) { CheckMatrixMultiply(*a, *b, *c); } -INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, - ::testing::Combine(::testing::ValuesIn(MatMulShapes), - ::testing::Bool(), ::testing::Bool(), - ::testing::Bool()), - EigenMatMulTest::Name); +INSTANTIATE_TEST_SUITE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, + ::testing::Combine(::testing::ValuesIn(MatMulShapes), + ::testing::Bool(), + ::testing::Bool(), + ::testing::Bool()), + EigenMatMulTest::Name); #ifdef INTEL_MKL class MKLMatMulTest : public CpuRuntimeTest, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 1457582ac19c27e5c3150b4667e6af505345a6bd..fae9670051a654f38f09856368ffb700b0c7a085 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" @@ -97,7 +96,7 @@ Status CpuTransferManager::TransferLiteralToInfeed( VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { int64 size = GetByteSizeRequirement(shape); return TransferBufferToInfeed(executor, size, literal.untyped_data()); } @@ -178,7 +177,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, MutableBorrowingLiteral literal) { - if (!ShapeUtil::IsTuple(literal_shape)) { + if (!literal_shape.IsTuple()) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from // literal_shape.dimensions() to the array slice on 2017-07-10. diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index 3ae64142cd7e32d3aa8d50870efaf94698c06440..c3c6847b7b77e2fb0470630815de9f5d7a6c5b9c 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -77,17 +77,16 @@ StatusOr Disassembler::DisassembleObjectFile( } // Sort the symbols in increasing address order. - std::sort( - symbols.begin(), symbols.end(), - [](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) { - // getAddress returns a Expected object. Assert there is no error - // before extracting the address. 
- llvm::Expected a_address_or_error = a.getAddress(); - CHECK(a_address_or_error); - llvm::Expected b_address_or_error = b.getAddress(); - CHECK(b_address_or_error); - return a_address_or_error.get() < b_address_or_error.get(); - }); + absl::c_sort(symbols, [](const llvm::object::SymbolRef& a, + const llvm::object::SymbolRef& b) { + // getAddress returns a Expected object. Assert there is no error + // before extracting the address. + llvm::Expected a_address_or_error = a.getAddress(); + CHECK(a_address_or_error); + llvm::Expected b_address_or_error = b.getAddress(); + CHECK(b_address_or_error); + return a_address_or_error.get() < b_address_or_error.get(); + }); // Construct ArrayRef pointing to section contents. llvm::StringRef section_content_string; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 97f9b85a606e140fd7f3b1e3ecfb0dd5ba289f03..2bf22ec6e43ea9944935a4d0d5dcd22c5d190c17 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -26,7 +26,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -41,932 +44,165 @@ namespace xla { using llvm_ir::SetToFirstInsertPoint; namespace cpu { - namespace { -// Provides tiled access to an in-memory rank 2 array. 
-class MemoryTile { - public: - // Constructs a MemoryTile that can operate on tiles consisting of - // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at - // `major_dim_offset` in the major dimension. The tile size along the minor - // dimension is the vector size, and that is implicitly determined by `vsl`. - MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, - llvm::Value* matrix, int64 matrix_size_along_minor_dim, - llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl), b_(b) { - pointers_.reserve(tile_size_along_major_dim); - for (int64 i = 0; i < tile_size_along_major_dim; i++) { - llvm::Value* total_offset = - b->CreateMul(b->getInt64(matrix_size_along_minor_dim), - b->CreateAdd(b->getInt64(i), major_dim_offset)); - pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); - } - } - - // Load a tile consisting of `tile_size_along_major_dim` vectors from position - // {major: `major_dim_offset`, minor: `minor_dim_offset`}. - // - // Note: `major_dim_offset` is a parameter to the constructor. - std::vector LoadTile(llvm::Value* minor_dim_offset) const { - std::vector result; - result.reserve(pointers_.size()); - for (const auto& pointer : pointers_) { - result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); - } - return result; - } - - // Stores `tile` to position {major: `major_dim_offset`, minor: - // `minor_dim_offset`}. - // - // Note: `major_dim_offset` is a parameter to the constructor. - void StoreTile(absl::Span tile, - llvm::Value* minor_dim_offset) const { - CHECK_EQ(tile.size(), pointers_.size()); - for (int64 i = 0; i < pointers_.size(); i++) { - vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); - } - } - - // Loads a tile of size [`tile_size_along_major_dim`, - // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, - // minor: `minor_dim_offset`} and then broadcasts each element into a vector - // of size vsl_.vector_size(). 
The (i,j)'th element of the return value is - // the (i,j)'th element in the tile broadcasted into an LLVM vector. - // - // Note: `major_dim_offset` is a parameter to the constructor. - std::vector> LoadBroadcastTile( - llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { - std::vector> result; - result.resize(pointers_.size()); - for (int64 i = 0; i < pointers_.size(); i++) { - for (int64 j = 0; j < tile_size_along_middle_dim; j++) { - result[i].push_back(vsl_->LoadBroadcast( - pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); - } - } - return result; - } - - private: - VectorSupportLibrary* vsl_; - llvm::IRBuilder<>* b_; - std::vector pointers_; -}; - -// The base class for the classes representing the GEMV emitter configurations. -// -// The IR emitted (modulo the LLVM values representing the input and output -// buffers) by the row major and column major GEMV emitters should be a function -// of their configuration. This is important because their configuration is -// used as a key to cache the generated IR. -class GemvConfig { - public: - // Mixin for convenience. - template - struct User { - public: - PrimitiveType scalar_type() const { - return derived().config().scalar_type(); - } - int64 tile_rows() const { return derived().config().tile_rows(); } - int64 tile_cols() const { return derived().config().tile_cols(); } - int64 m() const { return derived().config().m(); } - int64 k() const { return derived().config().k(); } - int64 has_addend() const { return derived().config().has_addend(); } - - private: - const T& derived() const { return *static_cast(this); } - }; +// Returns true if we should call into multi-threaded Eigen routines. 
+bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) { + return config.debug_options().xla_cpu_multi_thread_eigen(); +} - PrimitiveType scalar_type() const { return scalar_type_; } - int64 tile_rows() const { return tile_rows_; } - int64 tile_cols() const { return tile_cols_; } - int64 m() const { return m_; } - int64 k() const { return k_; } - bool has_addend() const { return has_addend_; } - - string GetCacheKey() const { - return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", - tile_rows(), "_", tile_cols(), "_", m(), "_", k(), - has_addend() ? "_with_addend" : ""); +// Represents a dot operation. We use this in lieu of an `HloInstruction` +// because we want to be able to create this for the "inner" dot operation in a +// batch dot, for which there is no separate HLO instruction. +struct DotInfo { + Shape lhs_shape; + Shape rhs_shape; + Shape result_shape; + DotDimensionNumbers dim_nums; + + DotInfo() = default; + + explicit DotInfo(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kDot); + lhs_shape = instr.operand(0)->shape(); + rhs_shape = instr.operand(1)->shape(); + result_shape = instr.shape(); + dim_nums = instr.dot_dimension_numbers(); } - - protected: - explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, bool has_addend) - : name_(std::move(name)), - scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), - has_addend_(has_addend) {} - - private: - string name_; - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; - bool has_addend_; }; -// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. 
-// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ +--+--+--+--+ -// |M00|M10|M20|M30| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M03|M13|M23|M33| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// -// (Legend: rows are horizontal and columns are vertical; and each column is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is from the column major left matrix. -// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] -// vector loaded from the RHS vector. -// -// As we iterate through the column dimension, we compute the change to the -// result vector by an elementwise multiplication between the two tiles above -// followed by a reduction along the major dimension: -// -// +-----------------------------------+ -// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | -// +-----------------------------------+ -// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | -// Result[R:R+4] += +-----------------------------------+ -// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | -// +-----------------------------------+ -// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | -// +-----------------------------------+ -// -// Where R is the starting row for the tile. -// -// We have an inner epilogue loop to deal with the "C" submatrix and an outer -// epilogue loop to deal with the B,D submarix. 
-// -// TODO(sanjoy): We should investigate if using gather loads and scatter stores -// can be used here have the same inner loop for both column-major and row-major -// matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter - : public GemvConfig::User { - public: - class Config : public GemvConfig { - public: - explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, - int64 m, int64 k, bool has_addend) - : GemvConfig(/*name=*/"col_major_gemv", scalar_type, - /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, - /*k=*/k, /*has_addend=*/has_addend) {} - }; - - ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, - llvm::IRBuilder<>* b) - : config_(config), - lhs_(lhs), - rhs_(rhs), - addend_(addend), - result_(result), - b_(b), - ksl_(b_), - vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { - CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); - CHECK(!has_addend() || addend != nullptr); - } - - void Emit(); - - const Config& config() const { return config_; } - - private: - void EmitOuterLoopBody(llvm::Value* column, int64 column_count, - bool is_first_column); - - MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { - return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m(), - /*major_dim_offset=*/column_start, - /*tile_size_along_major_dim=*/column_count); - } - - // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous - // sequence of `count` values, each one broadcasted to the vector width. 
- std::vector LoadRhsTile(llvm::Value* offset, int64 count) { - llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); - std::vector result; - result.reserve(count); - for (int64 i = 0; i < count; i++) { - result.push_back(vsl_.LoadBroadcast(base_pointer, i)); - } - return result; - } - - void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, - const std::vector& rhs_tile, - int64 columns, bool is_first_column); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, - bool is_first_tiled_column); - - Config config_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* addend_; - llvm::Value* result_; - llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; +// Dictates how a dot operation is implemented. +enum class DotImplementationStrategy { + // The dot operation is lowered into LLVM IR that implements a naive nested + // loop that computes the result one element at a time. This is our + // "fallback"; we don't really want this to kick in for any non-trival dot + // operation. + kNaiveLlvmIr, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Vector operation. This strategy also allows fusing in a bias add + // into the dot. The matrix can be row major or column major, both are + // supported. + kTiledLlvmIrGemv, + + // The dot operation is lowered into LLVM IR that implemetns a tiled + // Matrix*Matrix operation. No fusions are supported. The two inputs + // and the output have to be row major. + kTiledLlvmIrGemm, + + // The dot operation is lowered into a call into an Eigen routine. No fusions + // are supported today. The two inputs and the output have to be row major. + // However, we do allow transposing either the LHS or the RHS as part of the + // GEMM -- we expose this flexibility as flexibility in the contraction + // dimensions, but we can also see this as flexibility in the input layouts. 
+ kEigen, }; -void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( - llvm::Value* column, int64 column_count, bool is_first_column) { - MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, - /*column_count=*/column_count); - - std::vector rhs_tile = - LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, - /*columns=*/column_count, is_first_column); - EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); -} - -void ColumnMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k() % tile_cols(); - int64 column_limit = k() - column_remainder; - - ksl_.ForReturnVoid("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), is_first_column); - }); - - if (column_remainder != 0) { - EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, - column_limit == 0); - } -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, - int64 columns, bool is_first_column) { - int64 row_limit = m() - (m() % tile_rows()); - - ksl_.ForReturnVoid( - "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows(), [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? 
vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m() - (m() % tile_rows()); - if (row_start == m()) { - return; - } - - llvm::Value* columns_llvm = b_->getInt64(columns); - - // for (col = current_tile_col; col < (columns + current_tile_col); col++) - // for (row = row_start, row < m_; row++) { - // result[row] += lhs[row, col] * rhs[col] - // // Also take into account that if col is 0 then result[row] is not - // // initialized. - // } - - ksl_.ForReturnVoid( - "dot.inner.epilg.outer", /*start=*/current_tile_col, - /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), - /*step=*/1, /*peel_first_iteration=*/false, - [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { - llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); - llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), - /*step=*/1, [&](llvm::Value* scalar_row) { - llvm::Value* product = vsl_.Mul( - vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); - llvm::Value* setting_result_first_time = b_->CreateAnd( - is_first_scalar_col, b_->getInt1(is_first_tiled_column)); - ksl_.IfReturnVoid( - setting_result_first_time, - /*true_block_generator=*/ - [&]() { - if (addend_) { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), - product), - result_, scalar_row); - } else { - vsl_.StoreScalar(product, result_, scalar_row); - } - }, - /*false_block_generator=*/ - [&]() { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(result_, scalar_row), 
product), - result_, scalar_row); - }); - }); - }); -} +// Returns the implementation strategy for a dot with the configuration +// `dot_info`. +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features); -// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. -// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ -// |M00|M10|M20|M30| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| -// +---+---+---+---+ -// |M03|M13|M23|M33| -// +---+---+---+---+ -// -// (Legend: rows are horizontal and columns are vertical; and each row is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is loaded from the row major left matrix. -// b. The right vector is loaded from the RHS vector. -// -// We keep 4 vector accumulators accumulating the following four vector -// expressions as we iterate over the row dimension: -// -// +------+------+------+------+ -// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) -// +------+------+------+------+ -// -// In the end we do a horizontal reduction over these 4 vector accumulators to -// get 4 values in the result vector. -// -// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer -// epilogue loop to deal with the C,D submatrix. 
-class RowMajorMatrixVectorProductEmitter - : public GemvConfig::User { +// Helper class for emitting LLVM IR to perform the dot operation. +class DotOpEmitter { public: - class Config : public GemvConfig { - public: - explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, - int64 m, int64 k, bool has_addend) - : GemvConfig(/*name=*/"row_major_gemv", scalar_type, - /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, - /*k=*/k, /*has_addend=*/has_addend) {} - }; - - RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, llvm::IRBuilder<>* b) - : config_(config), - lhs_(lhs), - rhs_(rhs), - addend_(addend), - result_(result), - b_(b), - ksl_(b_), - vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { - CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); - CHECK(!has_addend() || addend != nullptr); - } - - void Emit(); - - const Config& config() const { return config_; } + explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); + + // Emits the IR to perform the dot operation. 
+ Status Emit(); private: - MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { - return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k(), - /*major_dim_offset=*/row_start, - /*tile_size_along_major_dim=*/row_count); - } - - void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - - void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, - std::vector* vector_accumulators); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, - std::vector* scalar_accumulators); - - Config config_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* addend_; - llvm::Value* result_; - llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; -}; - -void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, - int64 row_count) { - MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, - /*row_count=*/row_count); - std::vector vector_accumulators; - std::vector scalar_accumulators; - for (int i = 0; i < row_count; i++) { - vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); - scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); - } - EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, - &vector_accumulators); - EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, - &scalar_accumulators); - - std::vector accumulator_values; - std::transform( - vector_accumulators.begin(), vector_accumulators.end(), - std::back_inserter(accumulator_values), - [](const VectorVariable& vector_var) { return vector_var.Get(); }); - - std::vector horizontal_sums; - if (row_count == vsl_.vector_size()) { - if (addend_) { - horizontal_sums = vsl_.ComputeHorizontalSums( - std::move(accumulator_values), vsl_.LoadVector(addend_, row)); - } else { - horizontal_sums = - vsl_.ComputeHorizontalSums(std::move(accumulator_values)); - } - } else { - horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); - } - - for (int i = 
0; i < row_count; i++) { - llvm::Value* result_value = - vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); - llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); - if (addend_ && row_count != vsl_.vector_size()) { - result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); - } - vsl_.StoreScalar(result_value, result_, offset); - } -} + // Emits instructions to perform a scalar dot product (a multiply of the + // LHS and RHS) and store the results in the target. + Status EmitScalarDot(); -void RowMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 row_remainder = m() % tile_rows(); - int64 row_limit = m() - row_remainder; + // Emits a call to the CPU runtime to perform the matrix multiply. + Status EmitCallToRuntime(); - ksl_.ForReturnVoid( - "dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); - - if (row_remainder != 0) { - EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); - } -} + // Represents the dimensions of a matrix-matrix multiply operation. + struct MatMultDims { + // The number of rows in the LHS. + int64 m; -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - MemoryTile* lhs_memory_tile, int64 rows, - std::vector* vector_accumulators) { - int64 column_limit = k() - (k() % tile_cols()); - - ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set(vsl_.Add( - old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); -} + // The number of columns in the LHS, which is also must be equal to the + // number of rows in the RHS. 
+ int64 k; -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_row, int64 rows, - std::vector* scalar_accumulators) { - int64 column_start = k() - (k() % tile_cols()); - if (column_start == k()) { - return; - } + // The number of columns on the RHS. + int64 n; - for (int r = 0; r < rows; r++) { - llvm::Value* total_offset = b_->CreateMul( - b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( - "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); - } -} + // True if the LHS matrix is column major. + bool lhs_column_major; -// This class implements a tiled matrix multiplication algorithm, intended for -// multiplying small matrices that don't need cache tiling. -// -// In the future this can be used as the innermost GEBP loop in a GEMM kernel as -// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of -// high-performance matrix multiplication." ACM Transactions on Mathematical -// Software (TOMS) 34.3 (2008): 12.". -// -// This only supports canonical dot operations (i.e. where the lhs contraction -// dimension is 1 and the rhs contraction dimension is 0) over row major -// matrices. -class TiledSmallGemmEmitter { - public: - // Describe the dimensions of the kernel. - class Dimensions { - public: - explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; - int64 m() const { return m_; } - int64 k() const { return k_; } - int64 n() const { return n_; } + // True if the RHS matrix is column major. 
+ bool rhs_column_major; - string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } + // True if the RHS contraction dimension is not 0. + bool rhs_non_canonical; - private: - const int64 m_; - const int64 k_; - const int64 n_; + // True if the result matrix is column major. + bool target_column_major; }; - // Represents the configuration of the emitter. The LLVM IR emitted by the - // emitter, modulo the LLVM values holding the input and output buffers, must - // be a function of the instance of `Config` passed to it. - // - // `dims` holds the matrix multiplication dimensions. - // - // `max_vectorization_width` is the maximum vector width (i.e. the width of - // the largest vector register we will use). This can be larger than the - // largest vector register supported by the machine -- LLVM will legalize - // these large vector widths into legally sized vectors. - // - // `max_vector_count` is the maximum number of vectors of size - // `max_vectorization_width` that we will attempt to process at once. - // - // `min_vectorization_width` is the smallest vector width the emitter will use - // -- below that it will devolve to using a scalar loop. - // - // The innermost reduction loop executes the matrix multiply in tiles of size - // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, - // ] in the RHS. 
- class Config { - public: - explicit Config(PrimitiveType scalar_type, Dimensions dims, - int64 max_vectorization_width, int64 max_vector_count, - int64 min_vectorization_width, int64 tile_size_m, - int64 tile_size_k) - : scalar_type_(scalar_type), - dims_(dims), - max_vectorization_width_(max_vectorization_width), - max_vector_count_(max_vector_count), - min_vectorization_width_(min_vectorization_width), - tile_size_m_(tile_size_m), - tile_size_k_(tile_size_k) {} - - string GetCacheKey() const { - return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", - dims().ToString(), "_", max_vectorization_width(), - "_", min_vectorization_width(), "_", tile_size_m(), - "_", tile_size_k()); - } + // Get the MatMultDims instance for the dot product this DotOpEmitter + // represents. Precondition: the dot is of rank 2 (and thus its operands are + // of rank 2 as well). + MatMultDims GetMatMultDims() const; - PrimitiveType scalar_type() const { return scalar_type_; } - Dimensions dims() const { return dims_; } - int64 max_vectorization_width() const { return max_vectorization_width_; } - int64 max_vector_count() const { return max_vector_count_; } - int64 min_vectorization_width() const { return min_vectorization_width_; } - - int64 tile_size_m() const { return tile_size_m_; } - int64 tile_size_k() const { return tile_size_k_; } - - private: - PrimitiveType scalar_type_; - Dimensions dims_; - int64 max_vectorization_width_; - int64 max_vector_count_; - int64 min_vectorization_width_; - int64 tile_size_m_; - int64 tile_size_k_; - }; + // Lowers the dot operation as a tiled Matrix*Vector loop. + void EmitTiledLlvmIrGemv(); - // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies - // `lhs` with `rhs` and stores the result in `result`. 
- explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b) - : lhs_(lhs), - rhs_(rhs), - result_(result), - config_(config), - b_(b), - ksl_(b_) { - CHECK(max_vectorization_width() > 0 && - IsPowerOfTwo(static_cast(max_vectorization_width()))); - CHECK_GT(max_vector_count(), 0); - CHECK(min_vectorization_width() > 0 && - IsPowerOfTwo(static_cast(min_vectorization_width()))); - CHECK_GE(max_vectorization_width(), min_vectorization_width()); - CHECK_GT(tile_size_k(), 0); - } + // Lowers the dot operation as a tiled Matrix*Matrix loop. + void EmitTiledLlvmIrGemm(); - void Emit(); + // Lowers the dot operation as a naive nested loop that computes the result + // one element at a time. + void EmitNaiveLlvmIrGemm(); - private: - // The HandleResiduesOnX helpers split the iteration space for dimension X - // into a multiple of the tile size on dimension X and an epilogue. These - // helpers ultimately call into `EmitTiledGemm` for emitting the - // tiled GEMM kernel. - - void HandleResiduesOnN(); - void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, - llvm::Value* n_end); - void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end); - - // This emits a tiled GEMM kernel. For a detailed description see the comment - // on the implementation. 
- void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, - llvm::Value* m_end); - - llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } - - Config config() const { return config_; } - Dimensions dims() const { return config().dims(); } - - int64 max_vectorization_width() const { - return config().max_vectorization_width(); + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector + // registers. + int64 GetGemvTilingFactor() const { + const int64 kDefaultTilingFactor = 8; + return options::LlvmIrGemvTilingFactor(hlo_module_config_) + .value_or(kDefaultTilingFactor); } - int64 max_vector_count() const { return config().max_vector_count(); } - int64 min_vectorization_width() const { - return config().min_vectorization_width(); - } - int64 tile_size_m() const { return config().tile_size_m(); } - int64 tile_size_k() const { return config().tile_size_k(); } - PrimitiveType scalar_type() const { return config().scalar_type(); } - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* result_; - Config config_; + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. 
+ const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + DotInfo dot_info_; + string dot_hlo_name_; + const llvm_ir::IrArray& target_array_; + const llvm_ir::IrArray& lhs_array_; + const llvm_ir::IrArray& rhs_array_; + const llvm_ir::IrArray* addend_array_; + llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; + const HloModuleConfig& hlo_module_config_; + const TargetMachineFeatures& target_machine_features_; }; - -void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } - -void TiledSmallGemmEmitter::HandleResiduesOnN() { - // We can only iterate the `n` dimension for an extent that is divisible by - // the vectorization width. So we emit an outer loop that first processes the - // largest extent in `n` that is divisible by max_vectorization_width, then - // the largest remaining extent that is divisible by max_vectorization_width / - // 2 etc. 
- - int64 current_vectorization_width = - max_vector_count() * max_vectorization_width(); - int64 current_vector_count = max_vector_count(); - - int64 n_start = 0; - while (n_start != dims().n() && - current_vectorization_width >= min_vectorization_width()) { - int64 n_end = dims().n() - (dims().n() % current_vectorization_width); - if (n_start != n_end) { - VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, - "gemm"); - HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); - n_start = n_end; - } - if (current_vector_count == 1) { - current_vectorization_width /= 2; - } else { - current_vector_count--; - current_vectorization_width = - current_vector_count * max_vectorization_width(); - } - } - - if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); - ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { - llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); - HandleResiduesOnK(&vsl, n_i, n_i_next); - }); - } -} - -void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { - int64 k_start = 0; - int64 k_end = dims().k() - (dims().k() % tile_size_k()); - if (k_end != k_start) { - HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), - n_start, n_end); - k_start = k_end; - } - - if (k_start != dims().k()) { - HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), - GetInt64(dims().k()), n_start, n_end); - } -} - -void TiledSmallGemmEmitter::HandleResiduesOnM( - VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { - const int64 m_end = dims().m() - dims().m() % tile_size_m(); - EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), - GetInt64(0), GetInt64(m_end)); - - if (m_end != dims().m()) { - EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, - dims().m() - m_end, 
GetInt64(m_end), GetInt64(dims().m())); - } -} - -// The loop structure is: -// -// Iterate over dimension M as m: -// Iterate over dimension N as n: -// Iterate over dimension K as k: -// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) -// -// I.e. a just a tiled version of a "naive" GEMM. -// -// The tiling scheme is as follows: -// -// Let the LHS be: -// -// +----+----+----+ -// | a0 | b0 | c0 | . -// +----+----+----+ . -// | a1 | b1 | c1 | . -// +----+----+----+ -// .. .. -// -// and the RHS be: -// -// +----+----+----+----+ -// | p0 | p1 | p2 | p3 | . -// +----+----+----+----+ . -// | q0 | q1 | q2 | q3 | . -// +----+----+----+----+ -// | r0 | r1 | r2 | r3 | . -// +----+----+----+----+ . -// ...... ...... -// -// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted -// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] -// matrix that we can increment the result matrix by. -// -// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank -// 3 array, L, of dimension [2,3,4]: -// -// L[0,_,_] * L[1,_,_] -// * -// +----+----+----+----+ * +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | -// +----+----+----+----+ * +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | -// +----+----+----+----+ * +----+----+----+----+ -// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | -// +----+----+----+----+ * +----+----+----+----+ -// -// -// Then we FMA L[0,_,_] with the RHS to get the first row of the result and -// L[1,_,_] with the RHS to get the second row of the result. 
For example, -// L[0,_,_] is computed as: -// -// +----+----+----+----+ +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + -// +----+----+----+----+ +----+----+----+----+ -// -// +----+----+----+----+ +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + -// +----+----+----+----+ +----+----+----+----+ -// -// +----+----+----+----+ +----+----+----+----+ -// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | -// +----+----+----+----+ +----+----+----+----+ -// -// to get: -// -// +-------------------+-------------------+-------------------+--------- -// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... -// +-------------------+-------------------+-------------------+--------- -void TiledSmallGemmEmitter::EmitTiledGemm( - VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.ForReturnVoid( - "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { - MemoryTile result_memory_tile( - vsl, b_, /*matrix=*/result_, - /*matrix_size_along_minor_dim=*/dims().n(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/dims().k(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - ksl_.ForReturnVoid( - "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { - TileVariable result_tile_var(vsl, - result_memory_tile.LoadTile(n_i)); - ksl_.ForReturnVoid( - "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { - MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, - tile_size_k); - std::vector> lhs_tile = - lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); - std::vector rhs_tile = - rhs_memory_tile.LoadTile(n_i); - std::vector result_tile = - result_tile_var.Get(); - for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) 
{ - for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { - result_tile[r_m_i] = - vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], - result_tile[r_m_i]); - } - } - result_tile_var.Set(result_tile); - }); - - result_memory_tile.StoreTile(result_tile_var.Get(), n_i); - }); - }); -} - } // namespace -DotOpEmitter::DotOpEmitter(const HloInstruction& dot, +DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, @@ -975,7 +211,8 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) - : dot_(dot), + : dot_info_(std::move(dot_info)), + dot_hlo_name_(std::move(dot_hlo_name)), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -985,58 +222,9 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) { - PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, - addend_array, executable_run_options_value, b, - hlo_module_config, target_machine_features); - return dot_emitter.Emit(); -} - -bool DotOpEmitter::EmitSmallGemmIfProfitable( - const DotOpEmitter::MatMultDims& mat_mult_dims) { - if (ShouldUseMultiThreadedEigen()) { - return false; - } - - if (!EnableExperimentalLlvmIrGemm()) { - 
// TODO(sanjoy): We should make these numbers micro-arch specific. - bool small_gemm = mat_mult_dims.k <= 128 && - ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) || - (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32)); - if (!small_gemm) { - return false; - } - } - - if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { - return false; - } - - PrimitiveType primitive_type = dot_.shape().element_type(); - - switch (primitive_type) { - default: - return false; - - case F32: - case F64: - case S32: - case S64: - break; - } - - if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && - mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { - return false; - } +void DotOpEmitter::EmitTiledLlvmIrGemm() { + PrimitiveType primitive_type = dot_info_.result_shape.element_type(); + MatMultDims mat_mult_dims = GetMatMultDims(); llvm::Value* lhs = lhs_array_.GetBasePointer(); llvm::Value* rhs = rhs_array_.GetBasePointer(); @@ -1051,9 +239,8 @@ bool DotOpEmitter::EmitSmallGemmIfProfitable( } int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - b_->CreateMemSet( - target, b_->getInt8(0), size_bytes, - target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes, + /*Align=*/1); int64 max_target_vector_width = target_machine_features_.vector_register_num_elements( @@ -1063,47 +250,28 @@ bool DotOpEmitter::EmitSmallGemmIfProfitable( std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - TiledSmallGemmEmitter::Config config( - /*scalar_type=*/primitive_type, - TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, - /*max_vectorization_width=*/max_target_vector_width, - /*max_vector_count=*/tile_size_n_in_vector_width, - /*min_vectorization_width=*/std::min(4, max_target_vector_width), - /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); - - VLOG(2) << "Emitting GEMM kernel in 
LLVM IR with config " - << config.GetCacheKey(); - const bool enable_fast_math = hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); - KernelSupportLibrary::EmitAndCallOutlinedKernel( + EmitSmallGemm( + /*scalar_type=*/primitive_type, + /*m=*/m, /*k=*/k, /*n=*/n, + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs, + /*rhs=*/rhs, /*result=*/target, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, - rhs, target, - [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, - /*rhs=*/rhs, - /*result=*/target, b_); - small_gemm_emitter.Emit(); - }); - - return true; + /*optimize_for_size=*/optimize_for_size); } -bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2) { - return false; - } - - PrimitiveType primitive_type = dot_.shape().element_type(); +void DotOpEmitter::EmitTiledLlvmIrGemv() { + PrimitiveType primitive_type = dot_info_.result_shape.element_type(); - if (!primitive_util::IsFloatingPointType(primitive_type) && - !primitive_util::IsIntegralType(primitive_type)) { - return false; - } + CHECK(primitive_util::IsFloatingPointType(primitive_type) || + primitive_util::IsIntegralType(primitive_type)); MatMultDims mat_mult_dims = GetMatMultDims(); bool is_column_major_matrix_vector = false; @@ -1144,9 +312,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } } - if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return EmitSmallGemmIfProfitable(mat_mult_dims); - } + CHECK(is_column_major_matrix_vector || is_row_major_matrix_vector); int64 tiling_factor = 
GetGemvTilingFactor(); CHECK_GT(tiling_factor, 0); @@ -1178,44 +344,27 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter::Config config( + EmitColumnMajorGemv( /*scalar_type=*/primitive_type, /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, - /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); - - KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, + /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, + /*result=*/result_op, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), - lhs_op, rhs_op, - addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, - llvm::Value* addend_op, llvm::Value* result_op) { - ColumnMajorMatrixVectorProductEmitter emitter( - config, lhs_op, rhs_op, addend_op, result_op, b_); - emitter.Emit(); - }); + /*optimize_for_size=*/optimize_for_size); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - RowMajorMatrixVectorProductEmitter::Config config( + EmitRowMajorGemv( /*scalar_type=*/primitive_type, - /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, - /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); - - KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*tile_rows=*/tiling_factor, + /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, + /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, + /*result=*/result_op, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), - lhs_op, rhs_op, - addend_array_ ? 
addend_array_->GetBasePointer() : nullptr, result_op, - [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, - llvm::Value* addend_op, llvm::Value* result_op) { - RowMajorMatrixVectorProductEmitter emitter(config, lhs_op, rhs_op, - addend_op, result_op, b_); - emitter.Emit(); - }); + /*optimize_for_size=*/optimize_for_size); } - - return true; } Status DotOpEmitter::Emit() { @@ -1241,11 +390,6 @@ Status DotOpEmitter::Emit() { // which performs the sum-of-products (the reduction loop) before storing // the result in the output buffer. - // This routine assumes that the dot operation is not in a parallelized - // enclosing computation. - CHECK( - dot_.parent()->root_instruction()->outer_dimension_partitions().empty()); - const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); @@ -1256,27 +400,41 @@ Status DotOpEmitter::Emit() { return EmitScalarDot(); } - if (EmitLlvmIrDotIfProfitable()) { - return Status::OK(); + switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_, + target_machine_features_)) { + case DotImplementationStrategy::kNaiveLlvmIr: + EmitNaiveLlvmIrGemm(); + return Status::OK(); + + case DotImplementationStrategy::kTiledLlvmIrGemv: + EmitTiledLlvmIrGemv(); + return Status::OK(); + + case DotImplementationStrategy::kTiledLlvmIrGemm: + EmitTiledLlvmIrGemm(); + return Status::OK(); + + case DotImplementationStrategy::kEigen: + return EmitCallToRuntime(); } +} +void DotOpEmitter::EmitNaiveLlvmIrGemm() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { - return EmitCallToRuntime(); - } + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. 
- int64 lhs_reduction_dimension = - dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); - int64 rhs_reduction_dimension = - dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. - TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == - rhs_shape.dimensions(rhs_reduction_dimension)); + CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), + rhs_shape.dimensions(rhs_reduction_dimension)); bool lhs_reduction_along_minor_dimension = lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0); @@ -1286,7 +444,7 @@ Status DotOpEmitter::Emit() { // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), b_); + llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_); llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( @@ -1391,8 +549,6 @@ Status DotOpEmitter::Emit() { // Set the IR builder insert point to the exit basic block of the outer most // loop. 
b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); } Status DotOpEmitter::EmitScalarDot() { @@ -1406,16 +562,20 @@ Status DotOpEmitter::EmitScalarDot() { llvm::Value* rhs_value = rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_); if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { -#define REAL(x) b_->CreateExtractValue(x, {0}) -#define IMAG(x) b_->CreateExtractValue(x, {1}) - llvm::Value* real = - b_->CreateFSub(b_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), - b_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); - llvm::Value* imag = - b_->CreateFAdd(b_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), - b_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); -#undef IMAG -#undef REAL + auto get_real = [&](llvm::Value* x) { + return b_->CreateExtractValue(x, {0}); + }; + + auto get_imag = [&](llvm::Value* x) { + return b_->CreateExtractValue(x, {1}); + }; + + llvm::Value* real = b_->CreateFSub( + b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)), + b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value))); + llvm::Value* imag = b_->CreateFAdd( + b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)), + b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value))); result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType()); result = b_->CreateInsertValue(result, real, {0}); result = b_->CreateInsertValue(result, imag, {1}); @@ -1435,7 +595,7 @@ Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. 
- bool multi_threaded = ShouldUseMultiThreadedEigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; @@ -1483,11 +643,13 @@ Status DotOpEmitter::EmitCallToRuntime() { llvm::Function* function = b_->GetInsertBlock()->getParent(); llvm::Module* module = function->getParent(); - llvm::Function* matmul_func = llvm::cast( - module->getOrInsertFunction(fn_name, matmul_type)); - matmul_func->setCallingConv(llvm::CallingConv::C); - matmul_func->setDoesNotThrow(); - matmul_func->setOnlyAccessesArgMemory(); + llvm::FunctionCallee matmul_func = + module->getOrInsertFunction(fn_name, matmul_type); + if (auto* fn = llvm::dyn_cast(matmul_func.getCallee())) { + fn->setCallingConv(llvm::CallingConv::C); + fn->setDoesNotThrow(); + fn->setOnlyAccessesArgMemory(); + } // The Eigen runtime function expects column-major layout. If the matrices are // row major, then use the following identity to compute the product: @@ -1528,11 +690,11 @@ Status DotOpEmitter::EmitCallToRuntime() { } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { - CHECK_EQ(dot_.shape().dimensions_size(), 2); + CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2); const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); - const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); + const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; return { /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), @@ -1546,74 +708,6 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -// Return whether the given shape is rank 2. 
-static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; } - -// In a gemm operation where output = lhs * rhs, check whether the given shapes -// are valid for the operation. -static bool AreValidGemmShapes( - const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, - const TargetMachineFeatures& target_machine_features) { - // The inputs and the output must - // 1) be matrices with no padding, and - // 2) have an allowed element type. - PrimitiveType output_primitive_type = output_shape.element_type(); - if (!(output_primitive_type == F64 || output_primitive_type == F32 || - output_primitive_type == F16)) { - return false; - } - - if (!(IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape))) { - return false; - } - - auto is_aligned = [&](const Shape& shape) { - return GetMinimumAlignmentForArray(shape, target_machine_features) >= - TargetMachineFeatures::kEigenExpectedTensorAlignment; - }; - - if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || - !is_aligned(output_shape)) { - return false; - } - - return true; -} - -bool PotentiallyImplementedAsEigenDot( - const HloInstruction& hlo, - const TargetMachineFeatures& target_machine_features) { - // For certain types of Dot, we can call Eigen - if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - if (ShapeUtil::IsZeroElementArray(lhs_shape) || - ShapeUtil::IsZeroElementArray(rhs_shape)) { - return false; - } - - if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { - return false; - } - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), - target_machine_features)) { - const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); - // The size of the reduction dimension should match. 
The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); - return true; - } - } - - return false; -} - // For vector-matrix dot products, it is always profitable to make the Rhs // column major. absl::optional ProfitableToMakeDotOperandColumnMajor( @@ -1652,16 +746,319 @@ absl::optional ProfitableToMakeDotOperandColumnMajor( return {}; } -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { +namespace { +// Return whether the given shape is rank 2. +bool IsRank2(const Shape& shape) { return shape.rank() == 2; } + +bool IsSimpleLayout(const Layout& layout) { + return layout.tiles().empty() && layout.format() == DENSE; +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { + CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout())) + << lhs_shape.DebugString(); + CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout())) + << rhs_shape.DebugString(); + CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout())) + << output_shape.DebugString(); + + switch (output_shape.element_type()) { + case F64: + case F32: + case F16: + return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape); + default: + return false; + } +} + +bool IsAlignedGemm(const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) || + ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) { + return false; + } + + return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape, + dot_info.result_shape, target_machine_features); +} + +bool CanEmitTiledLlvmIrGemm( + 
const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + CHECK(IsAlignedGemm(dot_info, target_machine_features)); + + if (ShouldUseMultiThreadedEigen(config)) { + return false; + } + + int m = dot_info.result_shape.dimensions(0); + int k = dot_info.lhs_shape.dimensions( + dot_info.dim_nums.lhs_contracting_dimensions(0)); + int n = dot_info.result_shape.dimensions(1); + + if (!options::ForceEnableExperimentalLlvmIrGemm(config)) { + // TODO(sanjoy): We should make these numbers micro-arch specific. + bool small_gemm = + k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32)); + if (!small_gemm) { + return false; + } + } + + bool lhs_non_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 0; + bool rhs_non_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 1; + + if (lhs_non_canonical || rhs_non_canonical) { + return false; + } + + if (dot_info.result_shape.element_type() == F16) { + // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL + // adding this comment NFC. + return false; + } + + return true; +} + +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + PrimitiveType element_type = dot_info.result_shape.element_type(); // Any Matrix-Vector product of floating point or integral type, or // a transpose-dot fusion of the same can be lowered to a tiled LLVM // IR implementation. 
- const Shape& shape = dot.shape(); - return shape.dimensions_size() == 2 && - (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && - (primitive_util::IsFloatingPointType(shape.element_type()) || - primitive_util::IsIntegralType(shape.element_type())); + if (dot_info.result_shape.dimensions_size() == 2 && + (dot_info.result_shape.dimensions(0) == 1 || + dot_info.result_shape.dimensions(1) == 1) && + (primitive_util::IsFloatingPointType(element_type) || + primitive_util::IsIntegralType(element_type))) { + return DotImplementationStrategy::kTiledLlvmIrGemv; + } + + if (IsAlignedGemm(dot_info, target_machine_features)) { + return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features) + ? DotImplementationStrategy::kTiledLlvmIrGemm + : DotImplementationStrategy::kEigen; + } + + return DotImplementationStrategy::kNaiveLlvmIr; +} + +Status EmitNonBatchDotOperation( + DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + PrimitiveType type = target_array.GetShape().element_type(); + TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type || + C128 == type); + DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name), + target_array, lhs_array, rhs_array, addend_array, + executable_run_options_value, b, hlo_module_config, + target_machine_features); + return dot_emitter.Emit(); +} + +Shape DropFirstDim(const Shape& shape) { + absl::Span array_shape_dims(shape.dimensions()); + array_shape_dims.remove_prefix(1); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + array_shape_dims); +} + +Shape CollapseFirstNDims(const Shape& shape, int64 n) { + absl::Span input_shape_dims(shape.dimensions()); + int64 prefix_dim = + 
std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n, + 1ll, std::multiplies()); + DimensionVector result_dims; + result_dims.push_back(prefix_dim); + std::copy(input_shape_dims.begin() + n, input_shape_dims.end(), + std::back_inserter(result_dims)); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + result_dims); +} + +llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b, + const llvm_ir::IrArray& array, int64 n) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + const Shape& shape = array.GetShape(); + CHECK(shape.has_layout() && + LayoutUtil::IsMonotonicWithDim0Major(shape.layout())); + CHECK_GE(shape.dimensions_size(), n); + Shape new_shape = CollapseFirstNDims(shape, n); + llvm::Value* new_value = b->CreateBitCast( + array.GetBasePointer(), + llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo()); + return llvm_ir::IrArray(new_value, std::move(new_shape)); +} + +Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) { + // Checks some invariants that do not hold in general, but DotDecomposer + // should have established for us. This is just a debugging aid. + TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1); + std::vector batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size()); + absl::c_iota(batch_dim_numbers, 0); + TF_RET_CHECK( + absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions())); + TF_RET_CHECK( + absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions())); + return Status::OK(); +} + +// Slice out the inner array at batch index `batch_index` from `outer_array`. 
+llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array, + llvm::Value* batch_index, + llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + + Shape inner_shape = DropFirstDim(outer_array.GetShape()); + llvm_ir::IrArray::Index slice_index(b->getInt64Ty()); + slice_index.push_back(batch_index); + slice_index.InsertAt( + /*index=*/1, outer_array.GetShape().dimensions_size() - 1, + b->getInt64(0)); + llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b); + llvm::Type* slice_ptr_type = + llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo(); + return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type), + std::move(inner_shape)); +} + +Status EmitBatchDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers())); + + // Lower a batch dot into a sequence of non-batch dot operations. + + int64 num_batch_dims = + dot.dot_dimension_numbers().lhs_batch_dimensions_size(); + + // First reshape the inputs to make sure we only have one batch dimension. + // This is a no-op bitcast because the operands have to be in row-major layout + // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading + // dimensions (established by DotDecomposer and checked by + // ValidateDotDimensionNumbers above). 
+ llvm_ir::IrArray lhs_array_reshaped = + CollapseFirstNDims(b, lhs_array, num_batch_dims); + llvm_ir::IrArray rhs_array_reshaped = + CollapseFirstNDims(b, rhs_array, num_batch_dims); + llvm_ir::IrArray target_array_reshaped = + CollapseFirstNDims(b, target_array, num_batch_dims); + + int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0); + + KernelSupportLibrary ksl(b); + + return ksl.ForWithStatus( + llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count, + /*step=*/1, [&](llvm::Value* indvar) { + DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers(); + adjusted_dim_numbers.clear_lhs_batch_dimensions(); + adjusted_dim_numbers.clear_rhs_batch_dimensions(); + + // Create a DotInfo representing the "inner" non-batch dot operation. + DotInfo dot_info; + dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape()); + dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape()); + dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape()); + dot_info.dim_nums = dot.dot_dimension_numbers(); + dot_info.dim_nums.clear_lhs_batch_dimensions(); + dot_info.dim_nums.clear_rhs_batch_dimensions(); + + dot_info.dim_nums.set_lhs_contracting_dimensions( + 0, + dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims); + dot_info.dim_nums.set_rhs_contracting_dimensions( + 0, + dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims); + + llvm_ir::IrArray lhs_slice = + SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b); + llvm_ir::IrArray rhs_slice = + SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b); + llvm_ir::IrArray target_slice = SliceOutInnerArray( + target_array_reshaped, /*batch_index=*/indvar, b); + + // Emit the inner non-batch dot operation. 
+ return EmitNonBatchDotOperation( + dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr, + executable_run_options_value, b, hlo_module_config, + target_machine_features); + }); +} + +bool IsBatchDot(const HloInstruction& instr) { + if (auto* dot_instr = DynCast(&instr)) { + return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0; + } + + return false; +} +} // namespace + +bool DotImplementationCanHandleTranspose( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features) { + DotImplementationStrategy impl_strategy = + GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), + DotInfo(dot_instr), target_machine_features); + + // TODO(sanjoy): This is not quite right, it should be `impl_strategy == + // kEigen || impl_strategy == kTiledLlvmIrGemv || impl_strategy == + // kNaiveLlvmIr` but I'll fix this in a later CL in the interest of keeping + // the CL adding this comment NFC. + return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + impl_strategy == DotImplementationStrategy::kEigen; } +bool DotOperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features) { + DotImplementationStrategy impl_strategy = + GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), + DotInfo(dot_instr), target_machine_features); + + return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + impl_strategy == DotImplementationStrategy::kEigen; +} + +Status EmitDotOperation(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + // This routine assumes that the dot operation is not in a parallelized + // 
enclosing computation. + CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty()); + + if (IsBatchDot(dot)) { + TF_RET_CHECK(addend_array == nullptr); + return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array, + executable_run_options_value, b, + hlo_module_config, target_machine_features); + } + + return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array, + lhs_array, rhs_array, addend_array, + executable_run_options_value, b, + hlo_module_config, target_machine_features); +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 4c2041b556aa8bf8fe8fb8e0674c0f4f04f0acae..105bd3005c86d87443b2528eba7b0106ad70590e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -30,9 +30,16 @@ limitations under the License. namespace xla { namespace cpu { +// Returns true if the two operands and the output of `dot_instr` must have row +// major layout. +bool DotOperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features); -bool PotentiallyImplementedAsEigenDot( - const HloInstruction& hlo, +// Returns true if our lowering strategy for `dot_instr` can fold in transposes to +// either of the inputs. +bool DotImplementationCanHandleTranspose( + const HloInstruction& dot_instr, const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column @@ -41,129 +48,24 @@ bool PotentiallyImplementedAsEigenDot( absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo); -// Returns true to indicate that we can generate a tiled LLVM IR implementation -// for |dot|. -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); - -// Helper class for emitting LLVM IR to perform the dot operation. 
-class DotOpEmitter { - public: - // Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and - // place the result in target_array. IR is emitted at current insert point of - // the builder. Upon completion of the method, the insert point is set to the - // end of all instructions emitted for this operation. - // - // If `addend_array` is not nullptr then it must be an array of the same - // dimensions as the result, and the result is computed as `addend_array` + - // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported - // for Matrix-vector products. - static Status EmitDotOperation( - const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); - - private: - DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); - - // Emits the IR to perform the dot operation. - Status Emit(); - - // Emits instructions to perform a scalar dot product (a multiply of the - // LHS and RHS) and store the results in the target. - Status EmitScalarDot(); - - // Emit an LLVM IR implementation of the dot operation if we can. Returns - // true if an LLVM IR implementation was emitted. - bool EmitLlvmIrDotIfProfitable(); - - // Emits a call to the CPU runtime to perform the matrix multiply. - Status EmitCallToRuntime(); - - // Represents the dimensions of a matrix-matrix multiply operation. - struct MatMultDims { - // The number of rows in the LHS. 
- int64 m; - - // The number of columns in the LHS, which is also must be equal to the - // number of rows in the RHS. - int64 k; - - // The number of columns on the RHS. - int64 n; - - // True if the LHS matrix is column major. - bool lhs_column_major; - - // True if the LHS contraction dimension is not 1. - bool lhs_non_canonical; - - // True if the RHS matrix is column major. - bool rhs_column_major; - - // True if the RHS contraction dimension is not 0. - bool rhs_non_canonical; - - // True if the result matrix is column major. - bool target_column_major; - }; - - // Get the MatMultDims instance for the dot product this DotOpEmitter - // represents. Precondition: the dot is of rank 2 (and thus its operands are - // of rank 2 as well). - MatMultDims GetMatMultDims() const; - - bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims); - - // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector - // registers. - int64 GetGemvTilingFactor() const { - const int64 kDefaultTilingFactor = 8; - return options::LlvmIrGemvTilingFactor(hlo_module_config_) - .value_or(kDefaultTilingFactor); - } - - std::tuple GetGemmTileSize() const { - // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz - // - // TODO(b/80093688): Tune for other architectures and centralize this - // information in one place. - const std::tuple kDefaultTileSize = - std::tuple(11, 9, 1); - return options::LlvmIrGemmTileSize(hlo_module_config_) - .value_or(kDefaultTileSize); - } - - // Returns true if we should use an experimental implementation of GEMM - // (general matrix matrix multiplication) if possible. - bool EnableExperimentalLlvmIrGemm() const { - return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); - } - - // Returns true if we should call into multi-threaded Eigen routines. 
- bool ShouldUseMultiThreadedEigen() { - return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); - } - - const HloInstruction& dot_; - const llvm_ir::IrArray& target_array_; - const llvm_ir::IrArray& lhs_array_; - const llvm_ir::IrArray& rhs_array_; - const llvm_ir::IrArray* addend_array_; - llvm::Value* executable_run_options_value_; - llvm::IRBuilder<>* b_; - const HloModuleConfig& hlo_module_config_; - const TargetMachineFeatures& target_machine_features_; -}; - +// Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and +// place the result in target_array. IR is emitted at current insert point of +// the builder. Upon completion of the method, the insert point is set to the +// end of all instructions emitted for this operation. +// +// If `addend_array` is not nullptr then it must be an array of the same +// dimensions as the result, and the result is computed as `addend_array` + +// dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported +// for Matrix-vector products. +Status EmitDotOperation(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..cc28918ed60a8086135846e2b9b1b9d75ec31ef6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +// ----------------------------------------------------------------------------- +// INTERNAL HEADER. +// +// This file exposes internal implementation details from dot_op_emitter.cc for +// unit tests. Please do not depend on this! +// +// ----------------------------------------------------------------------------- + +namespace xla { +namespace cpu { +namespace internal { + +// Represents a dot operation. We use this in lieu of an `HloInstruction` +// because we want to be able to create this for the "inner" dot operation in a +// batch dot, for which there is no separate HLO instruction. +struct DotInfo { + Shape lhs_shape; + Shape rhs_shape; + Shape result_shape; + DotDimensionNumbers dim_nums; + + explicit DotInfo(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kDot); + lhs_shape = instr.operand(0)->shape(); + rhs_shape = instr.operand(1)->shape(); + result_shape = instr.shape(); + dim_nums = instr.dot_dimension_numbers(); + } +}; + +// Dictates how a dot operation is implemented. 
+enum class DotImplementationStrategy { + // The dot operation is lowered into LLVM IR that implements a naive nested + // loop that computes the result one element at a time. This is our + // "fallback"; we don't really want this to kick in for any non-trivial dot + // operation. + kNaiveLlvmIr, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Vector operation. This strategy also allows fusing in a bias add + // into the dot. The matrix can be row major or column major, both are + // supported. + kTiledLlvmIrGemv, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Matrix operation. No fusions are supported. The two inputs + // and the output have to be row major. + kTiledLlvmIrGemm, + + // The dot operation is lowered into a call into an Eigen routine. No fusions + // are supported today. The two inputs and the output have to be row major. + // However, we do allow transposing either the LHS or the RHS as part of the + // GEMM -- we expose this flexibility as flexibility in the contraction + // dimensions, but we can also see this as flexibility in the input layouts. + kEigen, +}; + +// Returns the implementation strategy for a dot with the configuration +// `dot_info`. 
+DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features); +} // namespace internal +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index c8312d80bd5012e5bcb42a410db18a7fa77a2eb6..0028fbaed895becad8da496aa8acdf7dc173a2a0 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -51,10 +51,11 @@ StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, return Unimplemented("atan2"); } // Create a function declaration. - llvm::Function* function = - llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(), - rhs->getType())); + llvm::Function* function = llvm::dyn_cast( + module_ + ->getOrInsertFunction(llvm_ir::AsStringRef(function_name), + lhs->getType(), lhs->getType(), rhs->getType()) + .getCallee()); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); @@ -85,9 +86,11 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, return Unimplemented("tanh"); } // Create a function declaration. 
- llvm::Function* function = llvm::cast( - module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name), - value->getType(), value->getType())); + llvm::Function* function = llvm::dyn_cast( + module_ + ->getOrInsertFunction(llvm_ir::AsStringRef(function_name), + value->getType(), value->getType()) + .getCallee()); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 1a8bedfe6afb4f096ddd4703c312b84d521a7ba5..a8b139aec9e96b6bb580baf74789df7c998cebf8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -26,7 +26,7 @@ namespace cpu { int64 GetMinimumAlignmentForArray( const Shape& shape, const TargetMachineFeatures& target_machine_features) { - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout())); // We don't require a layout to be set on `shape`. This only works on CPU diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4032c2da2f33ee61da8771ae6225a14172cbe6e8..2418d96440f9994842a54769cf6d561610ccfa18 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,11 +24,9 @@ limitations under the License. #include #include +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/platform/logging.h" -// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -70,6 +68,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/logging.h" namespace xla { @@ -77,7 +77,6 @@ namespace { using llvm_ir::AsStringRef; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -namespace gtl = tensorflow::gtl; } // namespace namespace cpu { @@ -87,7 +86,8 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine_features) + const TargetMachineFeatures* target_machine_features, + bool emit_code_for_msan) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -97,7 +97,8 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(*target_machine_features) { + target_machine_features_(*target_machine_features), + emit_code_for_msan_(emit_code_for_msan) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_cpu_enable_fast_math())); @@ -111,10 +112,9 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + absl::Span instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); - VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix - << "]; ordered? 
" << (instruction_order != nullptr); + VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; is_top_level_computation_ = is_top_level_computation; num_dynamic_loop_bounds_ = 0; if (!computation->root_instruction()->outer_dimension_partitions().empty()) { @@ -141,11 +141,7 @@ StatusOr IrEmitter::EmitComputation( bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(use_rdtscp); - if (instruction_order == nullptr) { - TF_RETURN_IF_ERROR(computation->Accept(this)); - } else { - TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); - } + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); llvm::Function* ir_function = compute_function_->function(); InsertOrDie(&emitted_functions_, computation, ir_function); // Delete 'compute_function', finalizing 'ir_function' and restoring caller @@ -228,11 +224,11 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { } Status IrEmitter::HandleCopy(HloInstruction* copy) { - if (ShapeUtil::IsTuple(copy->shape())) { + if (copy->shape().IsTuple()) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); - } else if (ShapeUtil::IsArray(copy->shape())) { + } else if (copy->shape().IsArray()) { // Use the elemental emitter for array shapes. return DefaultAction(copy); } @@ -244,10 +240,12 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); DCHECK_GE(byte_size, 0); - // Largest scalar is a complex64 so we don't need to worry about the + // Largest scalar is a complex128 so we don't need to worry about the // int64->int truncation here. 
- DCHECK_LE(byte_size, 8); - return byte_size; + DCHECK_LE(byte_size, 16); + + // Allocations may be 8-byte aligned if part of a small block. + return std::min(8LL, byte_size); } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -321,7 +319,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { auto on_false = tuple_select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); - TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RET_CHECK(tuple_select->shape().IsTuple()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select)); llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), @@ -351,7 +349,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, module_); - if (ShapeUtil::IsTuple(data_shape)) { + if (data_shape.IsTuple()) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // For a tuple, we first copy each of the internal elements to @@ -415,11 +413,18 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Function* acquire_func; if (kind == XfeedKind::kInfeed) { - acquire_func = llvm::cast(module_->getOrInsertFunction( - runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); + acquire_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type) + .getCallee()); } else { - acquire_func = llvm::cast(module_->getOrInsertFunction( - runtime::kAcquireOutfeedBufferForPopulationSymbolName, acquire_type)); + acquire_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kAcquireOutfeedBufferForPopulationSymbolName, + acquire_type) + .getCallee()); } acquire_func->setCallingConv(llvm::CallingConv::C); @@ -432,11 +437,19 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Function* 
release_func; if (kind == XfeedKind::kInfeed) { - release_func = llvm::cast(module_->getOrInsertFunction( - runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); + release_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kReleaseInfeedBufferAfterDequeueSymbolName, + release_type) + .getCallee()); } else { - release_func = llvm::cast(module_->getOrInsertFunction( - runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, release_type)); + release_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, + release_type) + .getCallee()); } release_func->setCallingConv(llvm::CallingConv::C); @@ -475,7 +488,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { const Shape& operand_shape = operand->shape(); llvm::Value* value = GetEmittedValueFor(operand); - if (!ShapeUtil::IsTuple(operand_shape)) { + if (!operand_shape.IsTuple()) { return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value); } @@ -498,6 +511,27 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { const HloSortInstruction* sort = Cast(hlo); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); Shape keys_shape = sort->keys()->shape(); + PrimitiveType keys_type = keys_shape.element_type(); + switch (keys_type) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case BF16: + case F16: + case S32: + case U32: + case F32: + case S64: + case U64: + case F64: + break; + default: + return Unimplemented( + "Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); + } std::vector destination_addresses(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { ShapeIndex shape_index = @@ -540,110 +574,52 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { higher_dimensions *= normalized_keys_shape.dimensions(i); } int64 lower_dimensions = 1; - for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + for (int64 i = 
normalized_keys_shape.rank() - 1; i > physical_dimension_to_sort; --i) { lower_dimensions *= normalized_keys_shape.dimensions(i); } - PrimitiveType keys_type = keys_shape.element_type(); - const char* fn_name = nullptr; - llvm::Type* keys_native_type = nullptr; - switch (keys_type) { - case PRED: - fn_name = runtime::kKeyValueSortPREDSymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case S8: - fn_name = runtime::kKeyValueSortS8SymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case U8: - fn_name = runtime::kKeyValueSortU8SymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case S16: - fn_name = runtime::kKeyValueSortS16SymbolName; - keys_native_type = b_.getInt16Ty()->getPointerTo(); - break; - case U16: - fn_name = runtime::kKeyValueSortU16SymbolName; - keys_native_type = b_.getInt16Ty()->getPointerTo(); - break; - case F16: - fn_name = runtime::kKeyValueSortF16SymbolName; - keys_native_type = b_.getHalfTy()->getPointerTo(); - break; - case S32: - fn_name = runtime::kKeyValueSortS32SymbolName; - keys_native_type = b_.getInt32Ty()->getPointerTo(); - break; - case U32: - fn_name = runtime::kKeyValueSortU32SymbolName; - keys_native_type = b_.getInt32Ty()->getPointerTo(); - break; - case F32: - fn_name = runtime::kKeyValueSortF32SymbolName; - keys_native_type = b_.getFloatTy()->getPointerTo(); - break; - case S64: - fn_name = runtime::kKeyValueSortS64SymbolName; - keys_native_type = b_.getInt64Ty()->getPointerTo(); - break; - case U64: - fn_name = runtime::kKeyValueSortU64SymbolName; - keys_native_type = b_.getInt64Ty()->getPointerTo(); - break; - case F64: - fn_name = runtime::kKeyValueSortF64SymbolName; - keys_native_type = b_.getDoubleTy()->getPointerTo(); - break; - default: - return Unimplemented( - "Element type %s not supported in the Sort op on CPU.", - PrimitiveType_Name(keys_type)); - } - + auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply()); + CHECK(absl::c_binary_search(thread_local_computations_, 
sort->to_apply())); llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( b_.getVoidTy(), - {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), - b_.getInt32Ty()->getPointerTo()}, + b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(), + b_.getInt64Ty()->getPointerTo(), less_than_function->getType()}, /*isVarArg=*/false); - auto* key_value_sort_func = llvm::cast( - module_->getOrInsertFunction(fn_name, key_value_sort_type)); + auto* key_value_sort_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction(runtime::kKeyValueSortSymbolName, + key_value_sort_type) + .getCallee()); key_value_sort_func->setCallingConv(llvm::CallingConv::C); key_value_sort_func->setDoesNotThrow(); - llvm::Value* values; - llvm::Value* sizes; - if (sort->values_count() == 0) { - values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()); - sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo()); - } else { - values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt8PtrTy(), b_.getInt32(sort->values_count()), - "cc_values_alloca", &b_); - sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca", - &b_); - for (int64 i = 0; i < sort->values_count(); ++i) { - llvm::Value* value_as_i8ptr = - PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy()); - llvm::Value* slot_in_values_alloca = - ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); - Store(value_as_i8ptr, slot_in_values_alloca); - llvm::Value* slot_in_sizes_alloca = - ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); - llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( - sort->operand(i + 1)->shape().element_type())); - Store(size, slot_in_sizes_alloca); - } + llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt8PtrTy(), 
b_.getInt32(sort->operand_count()), "cc_values_alloca", + &b_); + llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca", + &b_); + for (int64 i = 0; i < sort->operand_count(); ++i) { + llvm::Value* value_as_i8ptr = + PointerCast(destination_addresses[i], b_.getInt8PtrTy()); + llvm::Value* slot_in_values_alloca = + ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); + Store(value_as_i8ptr, slot_in_values_alloca); + llvm::Value* slot_in_sizes_alloca = + ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); + llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type())); + Store(size, slot_in_sizes_alloca); } Call(key_value_sort_func, - {PointerCast(destination_addresses[0], keys_native_type), - b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), + {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), b_.getInt64(lower_dimensions), values, - b_.getInt32(sort->values_count()), sizes}); + b_.getInt32(sort->operand_count()), sizes, + b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(), + GetProfileCountersArgument(), less_than_function}); if (sort->values_count() > 0) { llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, @@ -752,11 +728,6 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*reduce_window, - /*operands=*/{reduce_window->operand(0)}, - /*supported_types=*/{F32, BF16, S32, F16})); - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -784,8 +755,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { const auto init_value = select_and_scatter->operand(2); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = 
operand->shape().element_type(); - const int64 rank = ShapeUtil::Rank(operand->shape()); - CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + const int64 rank = operand->shape().rank(); + CHECK_EQ(rank, source->shape().rank()); CHECK_EQ(rank, window.dimensions_size()); // TODO(b/31410564): Implement dilation for select-and-scatter. @@ -947,12 +918,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { auto rhs = dot->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, F64, C64})); + /*supported_types=*/{F16, F32, F64, C64, C128})); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 || - dnums.rhs_batch_dimensions_size() > 0) { - return Unimplemented("Dot with batch dimensions not implemented."); - } if (dnums.lhs_contracting_dimensions_size() != 1) { // This is disallowed by ShapeInference today. @@ -975,10 +942,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { << llvm_ir::DumpToString(*target_array.GetBasePointer()); // Dot operation is complicated so we delegate to a helper class. 
- return DotOpEmitter::EmitDotOperation( - *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, - target_machine_features_); + return EmitDotOperation(*dot, target_array, lhs_array, rhs_array, + /*addend_array=*/nullptr, + GetExecutableRunOptionsArgument(), &b_, + hlo_module_config_, target_machine_features_); } StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( @@ -1123,7 +1090,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto rhs = convolution->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, C64})); + /*supported_types=*/{F16, F32, C64, C128})); // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. @@ -1236,8 +1203,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded " "conv2d function."; } - llvm::Function* conv_func = llvm::cast( - module_->getOrInsertFunction(fn_name, conv_type)); + llvm::Function* conv_func = llvm::dyn_cast( + module_->getOrInsertFunction(fn_name, conv_type).getCallee()); conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); @@ -1320,8 +1287,8 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { ? 
runtime::kEigenFftSymbolName : runtime::kEigenSingleThreadedFftSymbolName; - llvm::Function* fft_func = llvm::cast( - module_->getOrInsertFunction(fn_name, fft_type)); + llvm::Function* fft_func = llvm::dyn_cast( + module_->getOrInsertFunction(fn_name, fft_type).getCallee()); fft_func->setCallingConv(llvm::CallingConv::C); fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); @@ -1338,11 +1305,11 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { return Status::OK(); } -Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitter::HandleAllReduce(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on CPU. return Unimplemented( - "CrossReplicaSum with >1 replica is not implemented on CPU."); + "AllReduce with >1 replica is not implemented on CPU."); } // When there is a single replica, a cross replica sum is the identity @@ -1367,8 +1334,8 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { assignment_.GetUniqueSlice(crs, {i})); const Shape& operand_shape = crs->operand(i)->shape(); - CHECK(ShapeUtil::IsArray(operand_shape)) - << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + CHECK(operand_shape.IsArray()) + << "Operands to all-reduce must be arrays: " << crs->ToString(); operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. 
@@ -1404,7 +1371,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { int64 delta = 0; for (int64 i = 0; i < operand_shape.dimensions_size(); i++) { - if (reduced_dims.count(i)) { + if (reduced_dims.contains(i)) { delta++; } else { InsertOrDie(&unreduced_dim_map, i, i - delta); @@ -1417,7 +1384,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { for (int64 operand_dim_idx = 0; operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) { int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx); - if (!reduced_dims.count(operand_dim)) { + if (!reduced_dims.contains(operand_dim)) { if (FindOrDie(unreduced_dim_map, operand_dim) != result_shape.layout().minor_to_major(result_dim_idx++)) { return false; @@ -1714,10 +1681,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( vectorization_factor_in_bytes / ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); - bool is_reduction_over_minor_dimension = - std::find(dimensions.begin(), dimensions.end(), - LayoutUtil::Minor(arg->shape().layout(), 0)) != - dimensions.end(); + bool is_reduction_over_minor_dimension = absl::c_linear_search( + dimensions, LayoutUtil::Minor(arg->shape().layout(), 0)); unsigned element_alignment = tensorflow::MathUtil::GCD( ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), @@ -1729,7 +1694,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( return false; } - CHECK(!ShapeUtil::IsTuple(reduce->shape())); + CHECK(!reduce->shape().IsTuple()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); // We know we're not reducing over the most minor dimension, which means we @@ -1895,8 +1860,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( } Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support variadic reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { + // TODO(b/118333695): Support variadic reduce. 
+ if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on CPU"); } auto arg = reduce->mutable_operand(0); @@ -1995,7 +1960,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // The memcpy will copy elements that are logically this shape (allowed to be // scalar). const Shape logical_element_shape = ShapeUtil::FilterDimensions( - [&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); }, + [&inner_dims](int64 dim) { return inner_dims.contains(dim); }, operand->shape()); const int64 primitive_elements_per_logical_element = @@ -2210,10 +2175,10 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray addend_array( GetIrArrayFor(fusion->operand(addend_param_number))); - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, target_array, lhs_array, rhs_array, &addend_array, - GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, - target_machine_features_)); + TF_RETURN_IF_ERROR( + EmitDotOperation(*dot, target_array, lhs_array, rhs_array, + &addend_array, GetExecutableRunOptionsArgument(), &b_, + hlo_module_config_, target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2262,15 +2227,51 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { InBoundsGEP(operands_alloca, {b_.getInt64(i)}); Store(operand_as_i8ptr, slot_in_operands_alloca); } - auto* custom_call_ir_function = - llvm::cast(module_->getOrInsertFunction( - AsStringRef(custom_call_target), - llvm::FunctionType::get( - /*Result=*/b_.getVoidTy(), - /*Params=*/{i8_ptr_type, operands_alloca->getType()}, - /*isVarArg=*/false))); + if (emit_code_for_msan_) { + // Mark the alloca as initialized for msan. The buffer gets read by the + // custom callee, which might be msan-instrumented. + // TODO(b/66051036): Run the msan instrumentation pass instead. 
+ const llvm::DataLayout& dl = module_->getDataLayout(); + llvm::Type* intptr_type = b_.getIntPtrTy(dl); + auto* msan_unpoison_ir_function = llvm::cast( + module_ + ->getOrInsertFunction( + "__msan_unpoison", + llvm::FunctionType::get( + /*Result=*/b_.getVoidTy(), + /*Params=*/{i8_ptr_type, intptr_type}, /*isVarArg=*/false)) + .getCallee()); + Call(msan_unpoison_ir_function, + {PointerCast(operands_alloca, i8_ptr_type), + llvm::ConstantInt::get( + intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)}); + } + auto* custom_call_ir_function = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + AsStringRef(custom_call_target), + llvm::FunctionType::get( + /*Result=*/b_.getVoidTy(), + /*Params=*/{i8_ptr_type, operands_alloca->getType()}, + /*isVarArg=*/false)) + .getCallee()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + // Write the tuple table if the output is a tuple. + if (custom_call->shape().IsTuple()) { + std::vector base_ptrs; + for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape()); + ++i) { + const Shape& elem_shape = + ShapeUtil::GetTupleElementShape(custom_call->shape(), i); + TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented"; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(custom_call, {i})); + llvm::Value* addr = EmitBufferPointer(slice, elem_shape); + base_ptrs.push_back(addr); + } + llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_); + } auto* output_address_arg = PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); @@ -2391,8 +2392,7 @@ StatusOr IrEmitter::EmitFastConcatenate( int64 concat_dim = concatenate->dimensions(0); const Layout& output_layout = output_shape.layout(); auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); - auto concat_dim_layout_itr = - std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim); + auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim); std::vector 
inner_dims(output_min2maj.begin(), concat_dim_layout_itr); std::vector outer_dims(std::next(concat_dim_layout_itr), @@ -2792,7 +2792,7 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); llvm::LoadInst* param_address_untyped = Load(param_address_offset); - if (!ShapeUtil::IsOpaque(target_shape)) { + if (!target_shape.IsOpaque()) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); AttachDereferenceableMetadataForLoad(param_address_untyped, target_shape); @@ -2851,7 +2851,9 @@ llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, if (slice.allocation()->is_thread_local()) { return EmitThreadLocalBufferPointer(slice, target_shape); } else if (slice.allocation()->is_constant()) { - return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); + return BitCast( + FindOrDie(constant_buffer_to_global_, slice.allocation()->index()), + IrShapeType(target_shape)->getPointerTo()); } else { return EmitGlobalBufferPointer(slice, target_shape); } @@ -2944,8 +2946,7 @@ Status IrEmitter::ElementTypesSameAndSupported( TF_RET_CHECK(!operands.empty()); PrimitiveType primitive_type = operands[0]->shape().element_type(); - if (std::find(supported_types.begin(), supported_types.end(), - primitive_type) == supported_types.end()) { + if (!absl::c_linear_search(supported_types, primitive_type)) { return Unimplemented("unsupported operand type %s in op %s", PrimitiveType_Name(primitive_type), HloOpcodeString(instruction.opcode())); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 559a8162a2d53f28ea6817653503c216af90a610..0e372335f3aae919f9a9c559f86d4d61ab799b70 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -72,13 +72,15 @@ class IrEmitter : public DfsHloVisitorWithDefault, // index in the profiling array. 
// computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. + // emit_code_for_msan: whether emitted code should be compatible with msan. IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine); + const TargetMachineFeatures* target_machine, + bool emit_code_for_msan); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -101,7 +103,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + absl::Span instruction_order); llvm::IRBuilder<>* b() { return &b_; } @@ -134,7 +136,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSort(HloInstruction* sort) override; @@ -250,14 +252,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); - // Emits a function into the current module. This can be used for - // computations embedded inside other computations, such as the - // function that a map operation applies. - StatusOr EmitFunction( - HloComputation* function, // The function to emit. - absl::string_view - function_name_suffix); // Used for LLVM IR register names. - // Emits a call to a thread local function (e.g. 
to the computation nested // within a reduce or a map). Thread local callees (by definition) only write // to and read from thread local allocations. @@ -448,7 +442,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, computation_to_profile_idx_; // Maps HLOs to Values emitted for them. - std::unordered_map emitted_value_; + absl::flat_hash_map emitted_value_; llvm_ir::AliasAnalysis alias_analysis_; @@ -582,6 +576,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, std::vector thread_local_computations_; std::vector global_computations_; + bool emit_code_for_msan_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index adfb8392bf6fa356f0a5cdab3ff74036eca8918e..84a5b058cfb11c899eb6ae03478ed550b84dc819 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -266,9 +266,11 @@ Status EmitCallToParallelForkJoin( /*Params=*/compute_function_params, /*isVarArg=*/false); - llvm::Function* fork_join_func = - llvm::cast(module->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); + llvm::Function* fork_join_func = llvm::dyn_cast( + module + ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName, + fork_join_type) + .getCallee()); fork_join_func->setCallingConv(llvm::CallingConv::C); fork_join_func->setDoesNotThrow(); diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index f9722ffadac801521ddcbb568dd4435fd02e951b..93ef51754d21ad3ff4e24298c89649ef4c2742fb 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -36,57 +36,88 @@ const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX"; const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX"; namespace { -llvm::Function* 
EmitVectorF32TanhIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - bool enable_fast_math) { - llvm::Function* vector_tanh_function = module->getFunction(function_name); - if (vector_tanh_function == nullptr) { + +// Replaces calls to the function `fn_name` with the code generated by +// fn_body_generator. +// +// We assume that fn_name accepts either a scalar f32 or a vector of +// vector_width f32s, and that fn_body_generator generates a function body with +// the same inputs/outputs as fn_name. +void RewriteCalls( + llvm::Module* module, const char* fn_name, + std::function* b, llvm::Value* input, + int32 vector_width)> + fn_body_generator, + int32 vector_width, bool enable_fast_math) { + llvm::Function* fn = module->getFunction(fn_name); + if (fn == nullptr) { // If the function declaration is not present in the module, there can't be // any calls to resolve. Don't emit the function in this case. - return nullptr; + return; } - llvm::LLVMContext* context = &module->getContext(); + // Our task is to generate a function body for `fn`, but we can't generate a + // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it + // with a new function. 
+ if (fn->isIntrinsic()) { + llvm::Function* new_fn = llvm::Function::Create( + fn->getFunctionType(), llvm::GlobalValue::InternalLinkage, + llvm::Twine("xla_impl.") + fn_name, module); + fn->replaceAllUsesWith(new_fn); + fn->eraseFromParent(); + fn = new_fn; + } - llvm::BasicBlock* vector_tanh_body = - llvm::BasicBlock::Create(*context, "body", vector_tanh_function); + llvm::LLVMContext* context = &module->getContext(); - llvm::IRBuilder<> b(vector_tanh_body); + llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn); + llvm::IRBuilder<> b(fn_body); llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(enable_fast_math); b.setFastMathFlags(fast_math_flags); - llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); - b.CreateRet(llvm_ir::EmitFastTanh(&b, input)); - - DCHECK(!llvm::verifyFunction(*vector_tanh_function)); - return vector_tanh_function; -} + llvm::Value* input = &*fn->arg_begin(); -llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - bool enable_fast_math) { - llvm::Function* vector_exp_function = module->getFunction(function_name); - if (vector_exp_function == nullptr) { - // If the function declaration is not present in the module, there can't be - // any calls to resolve. Don't emit the function in this case. - return nullptr; + // Upcast to vector type if input is a scalar. + if (vector_width == 1) { + llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1); + input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input, + uint64_t{0}); } - llvm::LLVMContext* context = &module->getContext(); + // Generate the vectorized code. + CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); + llvm::Value* result = fn_body_generator(&b, input, vector_width); + + // Downcast result to scalar type if necessary. 
+ if (vector_width == 1) { + result = b.CreateExtractElement(result, uint64_t{0}); + } + b.CreateRet(result); + DCHECK(!llvm::verifyFunction(*fn)); - llvm::BasicBlock* vector_exp_body = - llvm::BasicBlock::Create(*context, "body", vector_exp_function); + // Force-inline `fn` into all of its callers and then delete `fn`. + // + // TODO(b/73081976): Should we avoid inlining these in some cases? + std::vector calls_to_inline; + for (auto* user : fn->users()) { + calls_to_inline.push_back(llvm::cast(user)); + } + for (auto* call_to_inline : calls_to_inline) { + llvm::InlineFunctionInfo inline_function_info; + CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); + } + fn->eraseFromParent(); +} - llvm::IRBuilder<> b(vector_exp_body); - llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); - b.setFastMathFlags(fast_math_flags); +llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input, + int32 /*vector_width*/) { + return llvm_ir::EmitFastTanh(b, input); +} - VectorSupportLibrary vsl(F32, vector_width, &b, "exp_f32"); +llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, + int32 vector_width) { + VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32"); // This implements the same polynomial approximation as implemented in Eigen3. @@ -107,7 +138,6 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); - llvm::Value* input = &*vector_exp_function->arg_begin(); llvm::Value* input_clamped = vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); @@ -128,49 +158,24 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. 
llvm::Value* vector_constant_0x7f = - b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); + b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); llvm::Value* vector_constant_23 = - b.CreateVectorSplat(vector_width, b.getInt32(23)); + b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b.getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width); // fx is clamped so we don't have to worry about it being out of range for // i32. - llvm::Value* emm0 = b.CreateFPToSI(fx, i32_vector_type); - emm0 = b.CreateAdd(emm0, vector_constant_0x7f); - emm0 = b.CreateShl(emm0, vector_constant_23); - llvm::Value* emm0_f32 = b.CreateBitCast(emm0, vsl.vector_type()); - - llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); + llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type); + emm0 = b->CreateAdd(emm0, vector_constant_0x7f); + emm0 = b->CreateShl(emm0, vector_constant_23); + llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type()); - b.CreateRet(result); - - DCHECK(!llvm::verifyFunction(*vector_exp_function)); - return vector_exp_function; + return vsl.Max(vsl.Mul(y, emm0_f32), input); } -llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - bool enable_fast_math) { - llvm::Function* vector_log_function = module->getFunction(function_name); - if (vector_log_function == nullptr) { - // If the function declaration is not present in the module, there can't be - // any calls to resolve. Don't emit the function in this case. 
- return nullptr; - } - - llvm::LLVMContext* context = &module->getContext(); - - llvm::BasicBlock* vector_log_body = - llvm::BasicBlock::Create(*context, "body", vector_log_function); - - llvm::IRBuilder<> b(vector_log_body); - llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); - b.setFastMathFlags(fast_math_flags); - - llvm::Value* input = &*vector_log_function->arg_begin(); - VectorSupportLibrary vsl(F32, vector_width, &b, "log_f32"); +llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, + int32 vector_width) { + VectorSupportLibrary vsl(F32, vector_width, b, "log_f32"); const llvm::APFloat half = GetIeeeF32(0.5); const llvm::APFloat one = GetIeeeF32(1.0); @@ -193,129 +198,107 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, // The smallest non denormalized float number. const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000); const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000); + const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000); const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000); // invalid_mask is set if x is negative or NaN (and therefore output // must be NaN). llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector()); - llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf); // Cut off denormalized stuff. - input = vsl.Max(min_norm_pos, input); + llvm::Value* tmp0 = vsl.Max(min_norm_pos, input); // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. 
llvm::Value* vector_constant_0x7f = - b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); + b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); llvm::Value* vector_constant_23 = - b.CreateVectorSplat(vector_width, b.getInt32(23)); + b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b.getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width); - llvm::Value* emm0 = - b.CreateLShr(b.CreateBitCast(input, i32_vector_type), vector_constant_23); + llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type), + vector_constant_23); // Keep only the fractional part. - input = vsl.FloatAnd(input, inv_mant_mask); - input = vsl.FloatOr(input, half); + tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask); + tmp0 = vsl.FloatOr(tmp0, half); - emm0 = b.CreateSub(emm0, vector_constant_0x7f); - llvm::Value* e = vsl.Add(one, b.CreateSIToFP(emm0, vsl.vector_type())); + emm0 = b->CreateSub(emm0, vector_constant_0x7f); + llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type())); // part2: // if( x < SQRTHF ) { // e -= 1; // x = x + x - 1.0; // } else { x = x - 1.0; } - llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF); - llvm::Value* tmp = vsl.FloatAnd(input, mask); - input = vsl.Sub(input, one); + llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF); + llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask); + tmp0 = vsl.Sub(tmp0, one); e = vsl.Sub(e, vsl.FloatAnd(mask, one)); - input = vsl.Add(input, tmp); + tmp0 = vsl.Add(tmp0, tmp1); - llvm::Value* x2 = vsl.Mul(input, input); - llvm::Value* x3 = vsl.Mul(x2, input); + llvm::Value* x2 = vsl.Mul(tmp0, tmp0); + llvm::Value* x3 = vsl.Mul(x2, tmp0); llvm::Value *y, *y1, *y2; - y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1); - y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4); - y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7); - y = vsl.MulAdd(y, input, cephes_log_p2); - y1 = vsl.MulAdd(y1, input, cephes_log_p5); - y2 = 
vsl.MulAdd(y2, input, cephes_log_p8); + y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1); + y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4); + y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7); + y = vsl.MulAdd(y, tmp0, cephes_log_p2); + y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5); + y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8); y = vsl.MulAdd(y, x3, y1); y = vsl.MulAdd(y, x3, y2); y = vsl.Mul(y, x3); y1 = vsl.Mul(cephes_log_q1, e); - tmp = vsl.Mul(half, x2); + llvm::Value* tmp2 = vsl.Mul(half, x2); y = vsl.Add(y, y1); - input = vsl.Sub(input, tmp); + tmp0 = vsl.Sub(tmp0, tmp2); y2 = vsl.Mul(cephes_log_q2, e); - input = vsl.Add(input, y); - input = vsl.Add(input, y2); + tmp0 = vsl.Add(tmp0, y); + tmp0 = vsl.Add(tmp0, y2); - // Negative arg will be NAN, 0 will be -INF. - llvm::Value* or_lhs = - vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask)); - llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); - llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); + // Contains +/-inf where +/-inf is the correct answer, otherwise 0. + llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf), + vsl.FloatAnd(is_pos_inf_mask, pos_inf)); - b.CreateRet(result); + // Contains a finite result or nan. This is the correct answer only if both + // result_minus_inf and result_pos_inf are both 0. + // + // (This implementation works because 0xffffffff is a nan.) + llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask); - DCHECK(!llvm::verifyFunction(*vector_log_function)); - return vector_log_function; + // Combine the above into a final result. 
+ return vsl.FloatOr(result_inf, + vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask), + result_finite_or_nan)); } } // namespace void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { - auto* tanh_v4f32 = - EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* tanh_v8f32 = - EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - auto* exp_v4f32 = - EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* exp_v8f32 = - EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - auto* log_v4f32 = - EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* log_v8f32 = - EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - // Gather all the call sites, force inline them and then delete the vector - // function bodies. - // - // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases? - - std::vector calls_to_inline; - for (auto* function : - {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { - if (function != nullptr) { - for (auto* user : function->users()) { - calls_to_inline.push_back(llvm::cast(user)); - } - } - } - - for (auto* call_to_inline : calls_to_inline) { - llvm::InlineFunctionInfo inline_function_info; - CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); - } - - for (auto* function : - {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { - if (function != nullptr) { - function->eraseFromParent(); - } - } + // Curry some params to RewriteCalls. 
+ auto rewrite_calls = + std::bind(RewriteCalls, module, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, enable_fast_math); + + rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1); + rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1); + rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4); + rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8); + + rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1); + rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1); + rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4); + rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8); + + rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1); + rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1); + rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4); + rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8); } } // namespace runtime diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index f8441c3e345504616485c6b34b4302acd5cc23a3..a6f4273a5a70aab0bc88383283d2a55b1ecb1681 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -34,7 +34,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); - CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!shape_.IsTuple()); CHECK(!ShapeUtil::IsScalar(shape_)); llvm_ir::ForLoopNest loop_nest(loop_name, b_); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index ede7f433ca6b2cc5629115f800348be9dfb2b93b..6121d1ca9a5c785cedd947200d3e7e320aa06bc2 100644 --- 
a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -146,11 +146,9 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || - PotentiallyImplementedAsEigenDot(*instruction, - target_machine_features_) || (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || - ShapeUtil::IsTuple(instruction->shape())) { + instruction->shape().IsTuple()) { return 1; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index f0b65046c14ccec5336abf7c4d05d1d755f783bd..35ae62b42dfa768c6abd0508097d6b235b2ebf54 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -112,10 +112,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - token = token[] after-all() - infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token) + token0 = token[] after-all() + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token0) infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 - ROOT outfeed0 = token[] outfeed(infeed0.data, token) + ROOT outfeed0 = token[] outfeed(infeed0.data, token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index 2d9492eacfea34bec3b0f1115e171a5328b7cdc3..6f72ddadf94d4c5b9add2ee66e0f4ac9a8ae9099 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -69,8 +69,13 @@ 
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); + CHECK_NE(function_ptr, nullptr); + CHECK_NE(partitions, nullptr); const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); + CHECK_NE(run_options, nullptr); + CHECK_NE(run_options->intra_op_thread_pool(), nullptr); + ComputeFunctionType function = reinterpret_cast(function_ptr); // Compute partition stride in 'partitions' array. diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 722aa3120ef4d8c957873ac58c361f19632dde1f..70a6d0af02c0c2db7208db561cf29e35a74707b2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -15,12 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include -#include #include -#include #include +#include #include -#include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/platform/dynamic_annotations.h" @@ -28,80 +26,15 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace { -using tensorflow::int16; using tensorflow::int32; using tensorflow::int64; -using tensorflow::int8; -using tensorflow::uint16; -using tensorflow::uint32; -using tensorflow::uint64; -using tensorflow::uint8; - -template -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements); -} - -// We would like a total order of floating point numbers so that the -// sort has a predictable behavior in the presence of NaNs. Rather -// than using floating point comparison, we use the following trick: -// If f is a float, and -// x = bit_cast(f); -// y = x < 0 ? 
0x7FFFFFFF - x : x; -// then y is ordered as an int32 such that finite values have the -// obvious order, -0 is ordered before 0, and -NaN and NaN appear at -// the beginning and end of the ordering. -template -CastType Convert(KeyType value) { - CastType casted_value; - memcpy(&casted_value, &value, sizeof(CastType)); - if (casted_value < 0) { - return static_cast(std::numeric_limits::max()) - - casted_value; - } - return casted_value; -} - -template -bool LessThan(KeyType lhs, KeyType rhs) { - return Convert(lhs) < - Convert(rhs); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, rhs.first); - }); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, rhs.first); - }); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, - int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan( - Eigen::half_impl::half_to_float(lhs.first), - Eigen::half_impl::half_to_float(rhs.first)); - }); -} +} // namespace -template -void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, - int32 values_count, - int32* values_primitive_type_size_in_bytes) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( + int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, int64* prof_counters, + void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) { // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // code, so msan can't tell they are initialized. 
TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*)); @@ -121,8 +54,9 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, int64 num_iteration_elements = a * c; int64 sort_dimension_offset = c; - std::unique_ptr[]> row_to_sort( - new std::pair[sort_dimension_elements]); + std::unique_ptr indices(new int64[sort_dimension_elements]); + std::unique_ptr comparison_values(new char*[2 * values_count]); + std::iota(indices.get(), indices.get() + sort_dimension_elements, 0); std::unique_ptr reordered_values( new std::string[sort_dimension_elements]); for (int64 index = 0; index < num_iteration_elements; ++index) { @@ -135,24 +69,33 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, int64 base_offset = index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; - // TODO(b/26783907): We could define a custom iterator class that references - // all arrays. Then we could avoid the intermediate copy. However this - // would become more complicated, and it is not clear if the benefit is high - // enough. - for (int64 i = 0; i < sort_dimension_elements; ++i) { - row_to_sort[i] = - std::make_pair(keys[base_offset + i * sort_dimension_offset], i); - } - KeyValueSort(row_to_sort.get(), sort_dimension_elements); - for (int64 i = 0; i < sort_dimension_elements; ++i) { - keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; + auto compare_function = [&](int64 a, int64 b) -> bool { + int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + for (int32 i = 0; i < values_count; ++i) { + comparison_values[i * 2] = values[i] + memory_index_lhs; + comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; + } + char result = 0; // Overwritten by less_than. 
+ less_than(&result, run_options, comparison_values.get(), nullptr, + prof_counters); + return result != 0u; + }; + if (is_stable) { + std::stable_sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); + } else { + std::sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); } - // Reorder the values according to the order defined by the keys. + // Reorder the values according to the order defined by 'indices'. for (int32 idx = 0; idx < values_count; ++idx) { for (int64 i = 0; i < sort_dimension_elements; ++i) { int64 memory_index = - (base_offset + row_to_sort[i].second * sort_dimension_offset) * + (base_offset + indices[i] * sort_dimension_offset) * values_primitive_type_size_in_bytes[idx]; reordered_values[i] = @@ -168,88 +111,3 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, } } } -} // namespace - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( - int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( - uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( - int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - 
KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( - uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( - Eigen::half* keys, int64 a, int64 b, int64 c, char** values, - int32 values_count, int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( - int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( - uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( - float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( - int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( - uint64* keys, int64 a, 
int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( - double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h index 7821099386969e855ea1737cf53ef49c15c6e93b..50c2911c3bd392b6df12717c34d250ce86ad26e0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -21,76 +21,26 @@ limitations under the License. extern "C" { -// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' -// dimension of 'keys' is sorted into ascending order. If 'values_count' is <= -// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr. -// If 'values_count' > 0, they contain exactly 'values_count' many elements. -// Each element of 'values' also represents a 3-dimensional shape with -// dimensions [a, b, c], and the size of the primitive type of the i-th shape -// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in -// each 'values' shape are reordered in such a way that if the element at index -// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values' -// shape is also moved to index 'j' (which means that the same elements -// correspond to each other as before). -extern void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, +// Each entry in 'values' represents a 3-dimensional shape with dimensions +// [a, b, c]. 
The 'b' dimension of each shape is sorted into ascending order +// according to the results of comparisons using the provided 'less_than' +// function. 'values_count' must be > 0 and specifies the number of entries in +// 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive +// type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' +// bytes. 'is_stable' specifies whether the sorting should be stable. +// 'run_options' and 'prof_counters' are passed through to the less-than +// function, which expects the following arguments: +// - pointer to the return value buffer (char*) +// - xla::ExecutableRunOptions = 'run_options' (char*) +// - pointers to the parameter buffers (char**) +// - pointers to the buffer tables = nullptr for thread local functions (char**) +// - profile counters = 'prof_counters' (int64*) +extern void __xla_cpu_runtime_KeyValueSort( + tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS8( - tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU8( - tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS16( - tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU16( - tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 
values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF16( - Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS32( - tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU32( - tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF32( - float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS64( - tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU64( - tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF64( - double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); + tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, tensorflow::int64* prof_counters, + void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)); } #endif // 
TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index a71a85913cfef271bc2a226cb0cf2dd4204499a4..fe7e87a197b6cf571195537eaea2898659cd5e2e 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -23,12 +23,20 @@ limitations under the License. #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; namespace { -template +bool Is16BytesAligned(void* ptr) { + return reinterpret_cast(ptr) % 16 == 0; +} + +template void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = @@ -46,11 +54,11 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, std::swap(rhs_rows, rhs_cols); } - const Eigen::TensorMap, Eigen::Aligned> A( - lhs, lhs_rows, lhs_cols); - const Eigen::TensorMap, Eigen::Aligned> B( - rhs, rhs_rows, rhs_cols); - Eigen::TensorMap, Eigen::Aligned> C(out, m, n); + const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, + lhs_cols); + const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, + rhs_cols); + Eigen::TensorMap, Alignment> C(out, m, n); typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 
0 : 1; @@ -65,14 +73,24 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } template -void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs, + int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + bool all_buffers_16b_aligned = + Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); + + if (!all_buffers_16b_aligned) { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); + return; + } + if (m == 1 || n == 1) { // Despite being single threaded, this version of matrix * vector is faster. xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } } @@ -82,20 +100,20 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, 
int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 16692e7f2e6145b2649b67987eef47916e958be2..1f7204e67a413efabd34cd7d88ced4c82ee7a5df 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -20,12 +20,20 @@ limitations under the License. #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; namespace { -template +bool Is16BytesAligned(void* ptr) { + return reinterpret_cast(ptr) % 16 == 0; +} + +template void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { int64 lhs_rows = m; @@ -40,11 +48,11 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, std::swap(rhs_rows, rhs_cols); } - const Eigen::TensorMap, Eigen::Aligned> A( - lhs, lhs_rows, lhs_cols); - const Eigen::TensorMap, Eigen::Aligned> B( - rhs, rhs_rows, rhs_cols); - Eigen::TensorMap, Eigen::Aligned> C(out, m, n); + const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, + lhs_cols); + const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, + rhs_cols); + Eigen::TensorMap, Alignment> C(out, m, n); typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 
0 : 1; @@ -59,14 +67,22 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } template -void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, - int64 m, int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, + T* rhs, int64 m, int64 n, int64 k, + int32 transpose_lhs, int32 transpose_rhs) { + bool all_buffers_16b_aligned = + Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); + + if (!all_buffers_16b_aligned) { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); + } + if (m == 1 || n == 1) { xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } } @@ -77,8 +93,8 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, + n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -87,8 +103,8 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -97,6 +113,6 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - 
SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index efccadedf27181a4cddf4f1dc3610f7c6db1d821..f7b64738b7b314b56f4ae60336d9c85c90287219 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -116,13 +116,26 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, orc_jit_memory_mapper::GetInstance()); result.Resolver = symbol_resolver_; return result; + }, + /*NotifyLoaded=*/ + llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(), + /*NotifyFinalized=*/ + [this](VModuleKeyT, const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { + this->NotifyObjectFinalized(object, object_info); + }, + /*NotifyFreed=*/ + [this](VModuleKeyT, const llvm::object::ObjectFile& object) { + this->NotifyObjectFreed(object); }), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, optimize_for_size, enable_fast_math, disable_expensive_passes, std::move(pre_optimization_hook), - std::move(post_optimization_hook))) { + std::move(post_optimization_hook))), + gdb_jit_event_listener_( + llvm::JITEventListener::createGDBRegistrationListener()) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } @@ -139,7 +152,7 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { } if (func_addr == nullptr) { - VLOG(2) << "Unable to resolve runtime symbol: " << name; + LOG(ERROR) << "Unable to resolve runtime symbol: " << name; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), @@ -147,6 +160,20 @@ 
llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { return symbol_info; } +void SimpleOrcJIT::NotifyObjectFinalized( + const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { + uint64_t key = static_cast( + reinterpret_cast(object.getData().data())); + gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info); +} + +void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) { + uint64_t key = static_cast( + reinterpret_cast(object.getData().data())); + gdb_jit_event_listener_->notifyFreeingObject(key); +} + SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( std::unique_ptr module) { auto key = execution_session_.allocateVModule(); @@ -213,18 +240,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); @@ -296,6 +312,9 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(sin, double (*)(double)); #ifdef __APPLE__ REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); + registry->Register("__sincosf_stret", + reinterpret_cast(__sincosf_stret)); + 
registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret)); #else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); #endif @@ -311,6 +330,18 @@ bool RegisterKnownJITSymbols() { registry->Register("memcpy", reinterpret_cast(memcpy)); registry->Register("memmove", reinterpret_cast(memmove)); registry->Register("memset", reinterpret_cast(memset)); + +#ifdef __APPLE__ + registry->Register("__bzero", reinterpret_cast(bzero)); + registry->Register("memset_pattern16", + reinterpret_cast(memset_pattern16)); +#endif + +#ifdef MEMORY_SANITIZER + registry->Register("__msan_unpoison", + reinterpret_cast(__msan_unpoison)); +#endif + return true; } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 78406ba143570183aea09d79db3f9b708c21bf70..3307c2f93d796bbdcd49af7f68e9f6c388e402ca 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/JITEventListener.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" @@ -99,6 +100,11 @@ class SimpleOrcJIT { private: llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); + void NotifyObjectFinalized( + const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info); + void NotifyObjectFreed(const llvm::object::ObjectFile& object); + std::vector module_keys_; std::unique_ptr target_machine_; const Disassembler disassembler_; @@ -107,6 +113,15 @@ class SimpleOrcJIT { std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; + + // Non owning pointer to a JIT event listener that registers the JIT events + // with an attached GDB. 
+ // + // Note: we get a pointer to this event listener using + // `createGDBRegistrationListener` which makes it look like we're supposed to + // free this, but the function is poorly named and really just returns a + // pointer to a static object. + llvm::JITEventListener* gdb_jit_event_listener_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index f8f5f392da8ab3348e63185aecf7b639daacaa42..8b7f843582b697058fe328fe69990122d868ada4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -16,7 +16,6 @@ limitations under the License. // Tests that we call into Eigen for dot operations as needed. #include -#include #include #include "absl/strings/str_cat.h" @@ -102,10 +101,10 @@ std::vector GetDotTestCases() { return result; } -INSTANTIATE_TEST_CASE_P(CpuEigenDotOperationTestInstantiation, - CpuEigenDotOperationTest, - ::testing::ValuesIn(GetDotTestCases()), - DotTestSpecToString); +INSTANTIATE_TEST_SUITE_P(CpuEigenDotOperationTestInstantiation, + CpuEigenDotOperationTest, + ::testing::ValuesIn(GetDotTestCases()), + DotTestSpecToString); } // namespace } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index 5cc6d01c0f15d4209cbc1fb259a0078fb9957f6e..f0f897e9635600b22e0c389ba056899e4d6ab3d4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -48,7 +48,7 @@ class InfeedTest : public ClientLibraryTestBase { ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); Infeed(&builder, literal.shape()); - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { // TODO(b/30609564): Use 
ComputeAndCompareLiteral instead. ComputeAndCompareTuple(&builder, literal, {}); } else { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 9b10c49f4f547edfb2164f98c49cceb031148bdc..9078b8fd1ff6cb0ddac89d5fcd13a9ccfae07763 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" @@ -59,8 +59,9 @@ class CpuUnaryIntrinsicTest string features{spec.features.data(), spec.features.size()}; if (!features.empty()) { - std::replace_if(features.begin(), features.end(), - [](char c) { return c != '_' && !isalnum(c); }, '_'); + std::replace_if( + features.begin(), features.end(), + [](char c) { return c != '_' && !absl::ascii_isalnum(c); }, '_'); } else { features = ""; } @@ -140,10 +141,10 @@ IntrinsicTestSpec CpuUnaryIntrinsicTestCases[] = { HloOpcode::kLog, kTriple_android_arm, "", R"(CHECK: fadd fast <4 x float> )"}}; -INSTANTIATE_TEST_CASE_P(CpuUnaryIntrinsicTestInstantiation, - CpuUnaryIntrinsicTest, - ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), - CpuUnaryIntrinsicTest::Name); +INSTANTIATE_TEST_SUITE_P(CpuUnaryIntrinsicTestInstantiation, + CpuUnaryIntrinsicTest, + ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), + CpuUnaryIntrinsicTest::Name); } // namespace } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc index 3934c03a04c978009282b3cd0d39bacf9b12a356..762ee67db9a1b2a753c6ec5538dee1d13282942e 100644 --- 
a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -26,10 +26,16 @@ TEST_F(CpuKeyValueSortTest, SortR1) { const string hlo_text = R"( HloModule KeyValueSort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY main { a = f32[10] parameter(0) - ROOT result = f32[10] sort(f32[10] a), dimensions={0} + ROOT result = f32[10] sort(f32[10] a), dimensions={0}, to_apply=compare } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index fa0e09ff6b5694c0e97963b83c6e541b858a1376..0584c0484f810a03ccccd522163f54535440ef8b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -31,29 +31,27 @@ HloModule RepeatedConstants while_body { arg_body = f32[2,3,2] parameter(0) ROOT const = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) } while_cond { arg_cond = f32[2,3,2] parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = token[] outfeed(f32[2,3,2] const_a, token[] token) - ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token) + token0 = token[] after-all() + out0 = token[] outfeed(f32[2,3,2] const_a, token[] token0) + ROOT out1 = token[] outfeed(f32[2,3,2] 
const_b, token[] token0) } )"; @@ -82,24 +80,24 @@ HloModule RepeatedConstants while_body { arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant(({ { 1 }, { 2 } }, {2} )) } while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant(( { { 1 }, { 2 } }, {2} )) const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token) + token0 = token[] after-all() + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token0) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index a7702c2aeeaff8a46a2c4f2785ccb873ea2c08e5..030bd41c2fc73eac41fe43c1acdf862d5dc97f98 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -75,8 +75,9 @@ TEST_F(CpuNoAliasTest, Concat) { // the buffers in the HLO module. We'll inspect these loads to ensure that // they have the expected alias information. 
llvm::Module ir_module("test", context); - llvm::Function* func = llvm::cast( - ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))); + llvm::Function* func = llvm::dyn_cast( + ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context)) + .getCallee()); llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); llvm::IRBuilder<> b(bb); auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index e2c7af541eede5265f274c72f55305549f059839..aab7f0b393881642437f1891256bd138823a3b87 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -28,12 +28,11 @@ HloModule Outfeed ENTRY main { const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - token = token[] after-all() - outfeed = token[] outfeed(f32[2,3,2] const_a, token) + token0 = token[] after-all() + outfeed = token[] outfeed(f32[2,3,2] const_a, token0) ROOT root = () tuple() } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..9fc472ff767441e60cf618ac9022e5c50ea20023 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -0,0 +1,1073 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" + +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace cpu { +namespace { + +using tensorflow::int64; + +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { + public: + // Constructs a MemoryTile that can operate on tiles consisting of + // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at + // `major_dim_offset` in the major dimension. The tile size along the minor + // dimension is the vector size, and that is implicitly determined by `vsl`. + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, + llvm::Value* matrix, int64 matrix_size_along_minor_dim, + llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) + : vsl_(vsl), b_(b) { + pointers_.reserve(tile_size_along_major_dim); + for (int64 i = 0; i < tile_size_along_major_dim; i++) { + llvm::Value* total_offset = + b->CreateMul(b->getInt64(matrix_size_along_minor_dim), + b->CreateAdd(b->getInt64(i), major_dim_offset)); + pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); + } + } + + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. 
+ // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector LoadTile(llvm::Value* minor_dim_offset) const { + std::vector result; + result.reserve(pointers_.size()); + for (const auto& pointer : pointers_) { + result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); + } + return result; + } + + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(absl::Span tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); + } + } + return result; + } + + private: + VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* b_; + std::vector pointers_; +}; + +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. 
This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + int64 has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? "_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + +// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. 
+//
+// We logically separate the LHS matrix into four segments:
+//
+//   +----------------------+---+
+//   |                      |   |
+//   |                      |   |
+//   |         A            | B |
+//   |                      |   |
+//   |                      |   |
+//   |                      |   |
+//   +----------------------+---+
+//   |         C            | D |
+//   +----------------------+---+
+//
+// where A is the largest submatrix of the LHS that can be evenly divided into
+// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
+//
+//   +---+---+---+---+       +--+--+--+--+
+//   |M00|M10|M20|M30|       |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M02|M12|M22|M32|       |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M03|M13|M23|M33|       |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//
+// (Legend: rows are horizontal and columns are vertical; and each column is one
+// llvm::Value of a vector type)
+//
+// where:
+//
+//   a. The left tile is from the column major left matrix.
+//   b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
+//      vector loaded from the RHS vector.
+//
+// As we iterate through the column dimension, we compute the change to the
+// result vector by an elementwise multiplication between the two tiles above
+// followed by a reduction along the major dimension:
+//
+//                     +-----------------------------------+
+//                     | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
+//                     +-----------------------------------+
+//                     | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
+// Result[R:R+4] +=    +-----------------------------------+
+//                     | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
+//                     +-----------------------------------+
+//                     | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
+//                     +-----------------------------------+
+//
+// Where R is the starting row for the tile.
+//
+// We have an inner epilogue loop to deal with the "C" submatrix and an outer
+// epilogue loop to deal with the B,D submatrix.
+//
+// TODO(sanjoy): We should investigate if gather loads and scatter stores
+// can be used here to have the same inner loop for both column-major and
+// row-major matrix-vector products.
+class ColumnMajorMatrixVectorProductEmitter
+    : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
+ public:
+  // GEMV configuration tagged with the "col_major_gemv" name, which becomes
+  // part of the cache key.
+  class Config : public GemvConfig {
+   public:
+    explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
+                    int64 m, int64 k, bool has_addend)
+        : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
+                     /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
+                     /*k=*/k, /*has_addend=*/has_addend) {}
+  };
+
+  // `lhs`, `rhs`, `addend` and `result` are the LLVM values for the buffers;
+  // `addend` may be null only if the config has no addend. The vector width
+  // used by the emitter is `config.tile_rows()`, which must be a positive
+  // power of two.
+  ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
+                                        llvm::Value* rhs, llvm::Value* addend,
+                                        llvm::Value* result,
+                                        llvm::IRBuilder<>* b)
+      : config_(config),
+        lhs_(lhs),
+        rhs_(rhs),
+        addend_(addend),
+        result_(result),
+        b_(b),
+        ksl_(b_),
+        vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") {
+    CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows())));
+    CHECK(!has_addend() || addend != nullptr);
+  }
+
+  // Emits the IR for the whole matrix-vector product.
+  void Emit();
+
+  const Config& config() const { return config_; }
+
+ private:
+  // Emits the work for `column_count` columns starting at `column`: the
+  // vectorized row loop followed by the scalar row epilogue.
+  void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
+                         bool is_first_column);
+
+  // The LHS is column major, so columns are the major dimension and `m()` is
+  // the size along the minor dimension.
+  MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) {
+    return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
+                      /*matrix_size_along_minor_dim=*/m(),
+                      /*major_dim_offset=*/column_start,
+                      /*tile_size_along_major_dim=*/column_count);
+  }
+
+  // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous
+  // sequence of `count` values, each one broadcasted to the vector width.
+  std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
+    llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
+    std::vector<llvm::Value*> result;
+    result.reserve(count);
+    for (int64 i = 0; i < count; i++) {
+      result.push_back(vsl_.LoadBroadcast(base_pointer, i));
+    }
+    return result;
+  }
+
+  // Emits the vectorized loop over rows [0, m - m % tile_rows).
+  void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile,
+                          const std::vector<llvm::Value*>& rhs_tile,
+                          int64 columns, bool is_first_column);
+
+  // Emits the scalar loop over the remaining rows [m - m % tile_rows, m).
+  void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
+                             bool is_first_tiled_column);
+
+  Config config_;
+  llvm::Value* lhs_;
+  llvm::Value* rhs_;
+  llvm::Value* addend_;
+  llvm::Value* result_;
+  llvm::IRBuilder<>* b_;
+  KernelSupportLibrary ksl_;
+  VectorSupportLibrary vsl_;
+};
+
+void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
+    llvm::Value* column, int64 column_count, bool is_first_column) {
+  MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column,
+                                                /*column_count=*/column_count);
+
+  std::vector<llvm::Value*> rhs_tile =
+      LoadRhsTile(column, /*count=*/column_count);
+  EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile,
+                     /*columns=*/column_count, is_first_column);
+  EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
+}
+
+void ColumnMajorMatrixVectorProductEmitter::Emit() {
+  // See the comment on the class declaration for the algorithm used here.
+  int64 column_remainder = k() % tile_cols();
+  int64 column_limit = k() - column_remainder;
+
+  ksl_.For("dot.outer.tiled",
+           /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
+           [&](llvm::Value* column, bool is_first_column) {
+             EmitOuterLoopBody(column, tile_cols(), is_first_column);
+           });
+
+  if (column_remainder != 0) {
+    // The leftover columns are the first ones processed only when the tiled
+    // loop above covered no columns at all, i.e. when column_limit == 0.
+    EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder,
+                      column_limit == 0);
+  }
+}
+
+void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
+    MemoryTile* lhs_memory_tile, const std::vector<llvm::Value*>& rhs_tile,
+    int64 columns, bool is_first_column) {
+  int64 row_limit = m() - (m() % tile_rows());
+
+  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
+           /*step=*/tile_rows(), [&](llvm::Value* row) {
+             std::vector<llvm::Value*> lhs_tile =
+                 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
+             // On the first column the result buffer is not read: the
+             // accumulator starts from the addend if there is one, else zero.
+             llvm::Value* accumulator =
+                 is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
+                                            : vsl_.GetZeroVector())
+                                 : vsl_.LoadVector(result_, row);
+             for (int i = 0; i < columns; i++) {
+               accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
+             }
+             vsl_.StoreVector(accumulator, result_, row);
+           });
+}
+
+void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
+    llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
+  int64 row_start = m() - (m() % tile_rows());
+  if (row_start == m()) {
+    // m is a multiple of tile_rows: no leftover rows.
+    return;
+  }
+
+  llvm::Value* columns_llvm = b_->getInt64(columns);
+
+  // for (col = current_tile_col; col < (columns + current_tile_col); col++)
+  //   for (row = row_start, row < m_; row++) {
+  //     result[row] += lhs[row, col] * rhs[col]
+  //     // Also take into account that if col is 0 then result[row] is not
+  //     // initialized.
+  //   }
+
+  ksl_.For(
+      "dot.inner.epilg.outer", /*start=*/current_tile_col,
+      /*end=*/b_->CreateAdd(columns_llvm, current_tile_col),
+      /*step=*/1, /*peel_first_iteration=*/false,
+      [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
+        llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
+        llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m()));
+        llvm::Value* lhs_base_pointer =
+            vsl_.ComputeOffsetPointer(lhs_, total_offset);
+        ksl_.For(
+            "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
+            /*step=*/1, [&](llvm::Value* scalar_row) {
+              llvm::Value* product = vsl_.Mul(
+                  vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
+              // True only for the first scalar column of the first tiled
+              // column, i.e. the first write to result[scalar_row].
+              llvm::Value* setting_result_first_time = b_->CreateAnd(
+                  is_first_scalar_col, b_->getInt1(is_first_tiled_column));
+              ksl_.If(
+                  setting_result_first_time,
+                  /*true_block_generator=*/
+                  [&]() {
+                    if (addend_) {
+                      vsl_.StoreScalar(
+                          vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
+                                   product),
+                          result_, scalar_row);
+                    } else {
+                      vsl_.StoreScalar(product, result_, scalar_row);
+                    }
+                  },
+                  /*false_block_generator=*/
+                  [&]() {
+                    vsl_.StoreScalar(
+                        vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
+                        result_, scalar_row);
+                  });
+            });
+      });
+}
+
+// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the
+// layout of the vector does not matter). This implementation uses a tiling
+// scheme to improve performance.
+//
+// We logically separate the LHS matrix into four segments:
+//
+//   +----------------------+---+
+//   |                      |   |
+//   |                      |   |
+//   |         A            | B |
+//   |                      |   |
+//   |                      |   |
+//   |                      |   |
+//   +----------------------+---+
+//   |         C            | D |
+//   +----------------------+---+
+//
+// where A is the largest submatrix of the LHS that can be evenly divided into
+// tiles.
+// For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
+//
+//   +---+---+---+---+
+//   |M00|M10|M20|M30|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M02|M12|M22|M32|
+//   +---+---+---+---+
+//   |M03|M13|M23|M33|
+//   +---+---+---+---+
+//
+// (Legend: rows are horizontal and columns are vertical; and each row is one
+// llvm::Value of a vector type)
+//
+// where:
+//
+//   a. The left tile is loaded from the row major left matrix.
+//   b. The right vector is loaded from the RHS vector.
+//
+// We keep 4 vector accumulators accumulating the following four vector
+// expressions as we iterate over the row dimension:
+//
+//   +------+------+------+------+
+//   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
+//   +------+------+------+------+
+//
+// In the end we do a horizontal reduction over these 4 vector accumulators to
+// get 4 values in the result vector.
+//
+// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer
+// epilogue loop to deal with the C,D submatrix.
+class RowMajorMatrixVectorProductEmitter
+    : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
+ public:
+  // GEMV configuration tagged with the "row_major_gemv" name, which becomes
+  // part of the cache key.
+  class Config : public GemvConfig {
+   public:
+    explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
+                    int64 m, int64 k, bool has_addend)
+        : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
+                     /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
+                     /*k=*/k, /*has_addend=*/has_addend) {}
+  };
+
+  // `lhs`, `rhs`, `addend` and `result` are the LLVM values for the buffers;
+  // `addend` may be null only if the config has no addend. The vector width
+  // used by the emitter is `tile_cols()`, which must be a positive power of
+  // two.
+  RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
+                                     llvm::Value* rhs, llvm::Value* addend,
+                                     llvm::Value* result, llvm::IRBuilder<>* b)
+      : config_(config),
+        lhs_(lhs),
+        rhs_(rhs),
+        addend_(addend),
+        result_(result),
+        b_(b),
+        ksl_(b_),
+        vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") {
+    CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols())));
+    CHECK(!has_addend() || addend != nullptr);
+  }
+
+  // Emits the IR for the whole matrix-vector product.
+  void Emit();
+
+  const Config& config() const { return config_; }
+
+ private:
+  // The LHS is row major, so rows are the major dimension and `k()` is the
+  // size along the minor dimension.
+  MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) {
+    return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
+                      /*matrix_size_along_minor_dim=*/k(),
+                      /*major_dim_offset=*/row_start,
+                      /*tile_size_along_major_dim=*/row_count);
+  }
+
+  // Emits the work for `row_count` rows starting at `row` and stores the
+  // corresponding `row_count` result elements.
+  void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
+
+  // Emits the vectorized loop over columns [0, k - k % tile_cols),
+  // accumulating into one vector accumulator per row.
+  void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows,
+                          std::vector<VectorVariable>* vector_accumulators);
+
+  // Emits the scalar loop over the remaining columns [k - k % tile_cols, k),
+  // accumulating into one scalar accumulator per row.
+  void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
+                             std::vector<ScalarVariable>* scalar_accumulators);
+
+  Config config_;
+  llvm::Value* lhs_;
+  llvm::Value* rhs_;
+  llvm::Value* addend_;
+  llvm::Value* result_;
+  llvm::IRBuilder<>* b_;
+  KernelSupportLibrary ksl_;
+  VectorSupportLibrary vsl_;
+};
+
+void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
+                                                           int64 row_count) {
+  MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row,
+                                                /*row_count=*/row_count);
+  // One vector accumulator (tiled columns) and one scalar accumulator
+  // (epilogue columns) per row, all starting at zero.
+  std::vector<VectorVariable> vector_accumulators;
+  std::vector<ScalarVariable> scalar_accumulators;
+  for (int i = 0; i < row_count; i++) {
+    vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
+    scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
+  }
+  EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count,
+                     &vector_accumulators);
+  EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
+                        &scalar_accumulators);
+
+  std::vector<llvm::Value*> accumulator_values;
+  std::transform(
+      vector_accumulators.begin(), vector_accumulators.end(),
+      std::back_inserter(accumulator_values),
+      [](const VectorVariable& vector_var) { return vector_var.Get(); });
+
+  std::vector<llvm::Value*> horizontal_sums;
+  if (row_count == vsl_.vector_size()) {
+    // With exactly one row per vector lane, the addend for these rows is
+    // loaded as a single vector and handed to ComputeHorizontalSums.
+    if (addend_) {
+      horizontal_sums = vsl_.ComputeHorizontalSums(
+          std::move(accumulator_values), vsl_.LoadVector(addend_, row));
+    } else {
+      horizontal_sums =
+          vsl_.ComputeHorizontalSums(std::move(accumulator_values));
+    }
+  } else {
+    horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values));
+  }
+
+  for (int i = 0; i < row_count; i++) {
+    llvm::Value* result_value =
+        vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
+    llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row);
+    // The addend was not folded into the horizontal sums above in this case,
+    // so add it element by element.
+    if (addend_ && row_count != vsl_.vector_size()) {
+      result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
+    }
+    vsl_.StoreScalar(result_value, result_, offset);
+  }
+}
+
+void RowMajorMatrixVectorProductEmitter::Emit() {
+  // See the comment on the class declaration for the algorithm used here.
+  int64 row_remainder = m() % tile_rows();
+  int64 row_limit = m() - row_remainder;
+
+  ksl_.For("dot.outer.tiled",
+           /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
+           [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
+
+  if (row_remainder != 0) {
+    EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder);
+  }
+}
+
+void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
+    MemoryTile* lhs_memory_tile, int64 rows,
+    std::vector<VectorVariable>* vector_accumulators) {
+  int64 column_limit = k() - (k() % tile_cols());
+
+  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
+           /*step=*/tile_cols(), [&](llvm::Value* col) {
+             std::vector<llvm::Value*> lhs_tile =
+                 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
+             llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
+             for (int i = 0; i < rows; i++) {
+               llvm::Value* old_sum = (*vector_accumulators)[i].Get();
+               (*vector_accumulators)[i].Set(
+                   vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
+             }
+           });
+}
+
+void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
+    llvm::Value* current_tile_row, int64 rows,
+    std::vector<ScalarVariable>* scalar_accumulators) {
+  int64 column_start = k() - (k() % tile_cols());
+  if (column_start == k()) {
+    // k is a multiple of tile_cols: no leftover columns.
+    return;
+  }
+
+  for (int r = 0; r < rows; r++) {
+    // Offset (in elements) of the start of row (`current_tile_row` + r).
+    llvm::Value* total_offset = b_->CreateMul(
+        b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k()));
+    llvm::Value* lhs_base_pointer =
+        vsl_.ComputeOffsetPointer(lhs_, total_offset);
+    ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
+             /*step=*/1, [&](llvm::Value* scalar_col) {
+               llvm::Value* product =
+                   vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
+                            vsl_.LoadScalar(rhs_, scalar_col));
+               llvm::Value* old_value = (*scalar_accumulators)[r].Get();
+               (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
+             });
+  }
+}
+
+// This class implements a tiled matrix multiplication algorithm, intended for
+// multiplying small matrices that don't need cache tiling.
+//
+// In the future this can be used as the innermost GEBP loop in a GEMM kernel as
+// described in "Goto, Kazushige, and Robert A. van de Geijn. "Anatomy of
+// high-performance matrix multiplication." ACM Transactions on Mathematical
+// Software (TOMS) 34.3 (2008): 12.".
+//
+// This only supports canonical dot operations (i.e. where the lhs contraction
+// dimension is 1 and the rhs contraction dimension is 0) over row major
+// matrices.
+class TiledSmallGemmEmitter {
+ public:
+  // Describe the dimensions of the kernel.
+  //
+  // Immutable triple (m, k, n). ToString() renders it as "<m>x<k>x<n>".
+  class Dimensions {
+   public:
+    explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {}
+
+    int64 m() const { return m_; }
+    int64 k() const { return k_; }
+    int64 n() const { return n_; }
+
+    string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); }
+
+   private:
+    const int64 m_;
+    const int64 k_;
+    const int64 n_;
+  };
+
+  // Represents the configuration of the emitter. The LLVM IR emitted by the
+  // emitter, modulo the LLVM values holding the input and output buffers, must
+  // be a function of the instance of `Config` passed to it.
+  //
+  // `dims` holds the matrix multiplication dimensions.
+  //
+  // `max_vectorization_width` is the maximum vector width (i.e. the width of
+  // the largest vector register we will use). This can be larger than the
+  // largest vector register supported by the machine -- LLVM will legalize
+  // these large vector widths into legally sized vectors.
+  //
+  // `max_vector_count` is the maximum number of vectors of size
+  // `max_vectorization_width` that we will attempt to process at once.
+  //
+  // `min_vectorization_width` is the smallest vector width the emitter will use
+  // -- below that it will devolve to using a scalar loop.
+  //
+  // The innermost reduction loop executes the matrix multiply in tiles of size
+  // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
+  // <vector size>] in the RHS.
+  class Config {
+   public:
+    explicit Config(PrimitiveType scalar_type, Dimensions dims,
+                    int64 max_vectorization_width, int64 max_vector_count,
+                    int64 min_vectorization_width, int64 tile_size_m,
+                    int64 tile_size_k)
+        : scalar_type_(scalar_type),
+          dims_(dims),
+          max_vectorization_width_(max_vectorization_width),
+          max_vector_count_(max_vector_count),
+          min_vectorization_width_(min_vectorization_width),
+          tile_size_m_(tile_size_m),
+          tile_size_k_(tile_size_k) {}
+
+    // Returns a string that identifies this configuration.
+    //
+    // Every field that influences the emitted IR has to appear here.
+    // `max_vector_count` is included because it determines the sequence of
+    // vector widths the emitter iterates over along `n`; leaving it out would
+    // give two configs that differ only in `max_vector_count` the same key.
+    string GetCacheKey() const {
+      return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
+                          dims().ToString(), "_", max_vectorization_width(),
+                          "_", max_vector_count(), "_",
+                          min_vectorization_width(), "_", tile_size_m(), "_",
+                          tile_size_k());
+    }
+
+    PrimitiveType scalar_type() const { return scalar_type_; }
+    Dimensions dims() const { return dims_; }
+    int64 max_vectorization_width() const { return max_vectorization_width_; }
+    int64 max_vector_count() const { return max_vector_count_; }
+    int64 min_vectorization_width() const { return min_vectorization_width_; }
+
+    int64 tile_size_m() const { return tile_size_m_; }
+    int64 tile_size_k() const { return tile_size_k_; }
+
+   private:
+    PrimitiveType scalar_type_;
+    Dimensions dims_;
+    int64 max_vectorization_width_;
+    int64 max_vector_count_;
+    int64 min_vectorization_width_;
+    int64 tile_size_m_;
+    int64 tile_size_k_;
+  };
+
+  // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
+  // `lhs` with `rhs` and stores the result in `result`.
+  // CHECK-fails unless both vectorization widths are positive powers of two
+  // with max >= min, and `max_vector_count`, `tile_size_m` and `tile_size_k`
+  // are all positive.
+  explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
+                                 llvm::Value* rhs, llvm::Value* result,
+                                 llvm::IRBuilder<>* b)
+      : lhs_(lhs),
+        rhs_(rhs),
+        result_(result),
+        config_(config),
+        b_(b),
+        ksl_(b_) {
+    CHECK(max_vectorization_width() > 0 &&
+          IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
+    CHECK_GT(max_vector_count(), 0);
+    CHECK(min_vectorization_width() > 0 &&
+          IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
+    CHECK_GE(max_vectorization_width(), min_vectorization_width());
+    // tile_size_m is used as a modulus when splitting the M dimension, so it
+    // must be validated just like tile_size_k.
+    CHECK_GT(tile_size_m(), 0);
+    CHECK_GT(tile_size_k(), 0);
+  }
+
+  // Emits the IR for the whole matrix multiplication.
+  void Emit();
+
+ private:
+  // The HandleResiduesOnX helpers split the iteration space for dimension X
+  // into a multiple of the tile size on dimension X and an epilogue. These
+  // helpers ultimately call into `EmitTiledGemm` for emitting the
+  // tiled GEMM kernel.
+
+  void HandleResiduesOnN();
+  void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
+                         llvm::Value* n_end);
+  void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k,
+                         llvm::Value* k_start, llvm::Value* k_end,
+                         llvm::Value* n_start, llvm::Value* n_end);
+
+  // This emits a tiled GEMM kernel. For a detailed description see the comment
+  // on the implementation.
+ void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 max_vector_count() const { return config().max_vector_count(); } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Config config_; + + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; +}; + +void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } + +void TiledSmallGemmEmitter::HandleResiduesOnN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. 
+ + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); + + int64 n_start = 0; + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, + "gemm"); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); + n_start = n_end; + } + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } + } + + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); + HandleResiduesOnK(&vsl, n_i, n_i_next); + }); + } +} + +void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims().k() - (dims().k() % tile_size_k()); + if (k_end != k_start) { + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); + k_start = k_end; + } + + if (k_start != dims().k()) { + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void TiledSmallGemmEmitter::HandleResiduesOnM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, 
GetInt64(m_end), GetInt64(dims().m())); + } +} + +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. a just a tiled version of a "naive" GEMM. +// +// The tiling scheme is as follows: +// +// Let the LHS be: +// +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. +// +// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: +// +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ +// +// +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result. 
+// For example, L[0,_,_] is computed as:
+//
+//   +----+----+----+----+   +----+----+----+----+
+//   | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 |   +
+//   +----+----+----+----+   +----+----+----+----+
+//
+//   +----+----+----+----+   +----+----+----+----+
+//   | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 |   +
+//   +----+----+----+----+   +----+----+----+----+
+//
+//   +----+----+----+----+   +----+----+----+----+
+//   | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
+//   +----+----+----+----+   +----+----+----+----+
+//
+// to get:
+//
+//   +-------------------+-------------------+-------------------+---------
+//   | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 |  ...
+//   +-------------------+-------------------+-------------------+---------
+void TiledSmallGemmEmitter::EmitTiledGemm(
+    VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
+    llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
+    int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
+  ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
+    // The result and LHS tiles both cover rows [m_i, m_i + tile_size_m).
+    MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_,
+                                  /*matrix_size_along_minor_dim=*/dims().n(),
+                                  /*major_dim_offset=*/m_i,
+                                  /*tile_size_along_major_dim=*/tile_size_m);
+    MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_,
+                               /*matrix_size_along_minor_dim=*/dims().k(),
+                               /*major_dim_offset=*/m_i,
+                               /*tile_size_along_major_dim=*/tile_size_m);
+    ksl_.For(
+        "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
+          // The running result tile starts from what is already stored in
+          // `result_` and is written back once after the K loop.
+          TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i));
+          ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
+            MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i,
+                                       tile_size_k);
+            std::vector<std::vector<llvm::Value*>> lhs_tile =
+                lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
+            std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
+            std::vector<llvm::Value*> result_tile = result_tile_var.Get();
+            for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
+              for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
+                result_tile[r_m_i] =
+                    vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
+                                result_tile[r_m_i]);
+              }
+            }
+            result_tile_var.Set(result_tile);
+          });
+
+          result_memory_tile.StoreTile(result_tile_var.Get(), n_i);
+        });
+  });
+}
+
+// Given a pointer type, strips every array type nested in its pointee and
+// returns a pointer type to the innermost element type.
+llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) {
+  llvm::Type* type =
+      llvm::cast<llvm::PointerType>(pointer_type)->getElementType();
+  while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) {
+    type = array_type->getElementType();
+  }
+
+  return type->getPointerTo();
+}
+
+// The GEMV buffers after bitcasting each to a pointer to its scalar element
+// type. `addend_canonicalized` is null when there is no addend.
+struct GemvBuffersWithCanonicalType {
+  llvm::Value* lhs_canonicalized;
+  llvm::Value* rhs_canonicalized;
+  llvm::Value* addend_canonicalized;
+  llvm::Value* result_canonicalized;
+};
+
+GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType(
+    llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
+    llvm::Value* result, llvm::IRBuilder<>* b) {
+  // We characterize a GEMV operation via M and K, since N is implicitly 1.
+  // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented
+  // by the same GEMV that multiplies [5,6] with [1,6]. However, the
+  // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial
+  // sense -- the in memory representations are the same) since they're computed
+  // from the `xla::Shape`s. Since we want to be able to call the same
+  // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV
+  // inputs here into the same type.
+  GemvBuffersWithCanonicalType buffers_with_canonical_type;
+  llvm::Type* lhs_type = lhs->getType();
+  llvm::Type* rhs_type = rhs->getType();
+  llvm::Type* addend_type = addend ? addend->getType() : nullptr;
+  llvm::Type* result_type = result->getType();
+
+  buffers_with_canonical_type.lhs_canonicalized =
+      b->CreateBitCast(lhs, GetPointerToElementType(lhs_type));
+  buffers_with_canonical_type.rhs_canonicalized =
+      b->CreateBitCast(rhs, GetPointerToElementType(rhs_type));
+  buffers_with_canonical_type.addend_canonicalized =
+      addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type))
+             : nullptr;
+  buffers_with_canonical_type.result_canonicalized =
+      b->CreateBitCast(result, GetPointerToElementType(result_type));
+
+  return buffers_with_canonical_type;
+}
+
+}  // namespace
+
+// Emits a call to an outlined row-major GEMV kernel. The kernel is keyed by
+// the configuration's cache key and operates on the canonicalized buffers; a
+// null `addend` means "no addend".
+void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
+                      int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
+                      llvm::Value* rhs, llvm::Value* addend,
+                      llvm::Value* result, llvm::IRBuilder<>* b,
+                      bool enable_fast_math, bool optimize_for_size) {
+  RowMajorMatrixVectorProductEmitter::Config config(
+      /*scalar_type=*/scalar_type,
+      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
+      /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
+
+  GemvBuffersWithCanonicalType canonical_inputs =
+      GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
+
+  KernelSupportLibrary::EmitAndCallOutlinedKernel(
+      /*enable_fast_math=*/enable_fast_math,
+      /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
+      canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
+      canonical_inputs.addend_canonicalized,
+      canonical_inputs.result_canonicalized,
+      [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
+                                      llvm::Value* addend,
+                                      llvm::Value* result) {
+        RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
+                                                   result, b);
+        emitter.Emit();
+      });
+}
+
+void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
+                         int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
+                         llvm::Value* rhs, llvm::Value* addend,
+                         llvm::Value* result, llvm::IRBuilder<>* b,
+                         bool enable_fast_math, bool optimize_for_size) {
+  ColumnMajorMatrixVectorProductEmitter::Config config(
+      /*scalar_type=*/scalar_type,
+      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
+      /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
+
+  GemvBuffersWithCanonicalType canonical_inputs =
+      GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
+
+  KernelSupportLibrary::EmitAndCallOutlinedKernel(
+ /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), + canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, + [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, + llvm::Value* result) { + ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, + result, b); + emitter.Emit(); + }); +} + +void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + TiledSmallGemmEmitter::Config config( + /*scalar_type=*/scalar_type, + TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_vectorization_width, + /*max_vector_count=*/max_vector_count, + /*min_vectorization_width=*/min_vectorization_width, + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) { + TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/result, b); + small_gemm_emitter.Emit(); + }); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..0a82326cc3704bce8c122261383249c60eda1f3a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace cpu { + +// These routines emit LLVM IR implementing tiled GEMM and GEMV routines. 
+ +void EmitRowMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, + tensorflow::int64 tile_cols, tensorflow::int64 m, + tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, llvm::Value* result, + llvm::IRBuilder<>* b, bool enable_fast_math, + bool optimize_for_size); + +void EmitColumnMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, + tensorflow::int64 tile_cols, tensorflow::int64 m, + tensorflow::int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size); + +void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m, + tensorflow::int64 k, tensorflow::int64 n, + tensorflow::int64 max_vectorization_width, + tensorflow::int64 max_vector_count, + tensorflow::int64 min_vectorization_width, + tensorflow::int64 tile_size_m, tensorflow::int64 tile_size_k, + llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b, bool enable_fast_math, + bool optimize_for_size); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 5690d2be2fe3e21c96b51a5226e0b29148217fd1..c444fd7d4aa88fa21b1aa2b2f058bd689b234b15 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -114,6 +114,9 @@ class VectorSupportLibrary { // raison d'etre) less cluttered. 
llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs)); + } llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) { diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index e84bf00153aa28df29d8df486b92654feab4afbf..2f7fddb96da2dbb4e3f824daa483d5bcd027460f 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -103,11 +103,19 @@ class DfsHloVisitorBase { virtual Status HandlePower(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } + virtual Status HandleSqrt(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleRsqrt(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; - virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; + virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 80ea5be298aea44a0f424398da74c4e478f10346..341bb37b8355e9987a0331d0a66bb8fe87f019cf 100644 --- 
a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -91,7 +91,10 @@ class DfsHloVisitorWithDefaultBase Status HandleFft(HloInstructionPtr fft) override { return DefaultAction(fft); } - Status HandleCrossReplicaSum(HloInstructionPtr crs) override { + Status HandleTriangularSolve(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } Status HandleAllToAll(HloInstructionPtr hlo) override { @@ -100,6 +103,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCollectivePermute(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleReplicaId(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc index 825e1436f0ec6d49b555e5e3e9c2c7a19fb7b062..70173d43d79e931b75f131ad380ad98359cc78b8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -73,15 +73,14 @@ ENTRY TestComputation { abs = f32[] abs(arg) add = f32[] add(arg, gte) broadcast = f32[42] broadcast(add), dimensions={} - slice = f32[0] slice(broadcast), slice={[1:2]} + slice = f32[1] slice(broadcast), slice={[1:2]} copy = f32[] copy(arg) eq = pred[] equal-to(arg, gte) neg = f32[] negate(arg) ROOT convert = f64[] convert(f32[] arg) })"; std::unique_ptr module = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()) - .ConsumeValueOrDie(); + ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie(); ElementwiseTestVisitor visitor; TF_EXPECT_OK(module->entry_computation()->Accept(&visitor)); } diff --git 
a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index b2ba2617902104bfea06713332fa1c2aedea536d..559b9c1f2c9f341293ca89adc61e3312fd9f313c 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -156,29 +158,192 @@ Status DecomposeBatchDot(HloInstruction* dot) { return computation->ReplaceInstruction(dot, new_dot); } +// Convert a dot into a canonical form where non-contracting and contracting +// dimensions are reshaped together and batch dimensions are the most major +// dimensions. The requires transposing and reshapes the lhs and rhs and +// reshaping the output batch to the original shape. 
+Status CanonicalizeDot(HloInstruction* original_dot) { + auto computation = original_dot->parent(); + const auto& original_dnums = original_dot->dot_dimension_numbers(); + const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size(); + const int64 num_contracting_dims = + original_dnums.lhs_contracting_dimensions_size(); + + const auto& lhs_shape = original_dot->operand(0)->shape(); + const int64 lhs_rank = lhs_shape.rank(); + const int64 num_lhs_non_contracting_dims = + lhs_rank - num_batch_dims - num_contracting_dims; + + std::vector lhs_non_contracting_dims; + lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims); + int64 lhs_contracting_size = 1; + int64 lhs_non_contracting_size = 1; + std::vector batch_dim_sizes; + batch_dim_sizes.reserve(num_batch_dims); + for (int64 i = 0; i < lhs_rank; ++i) { + if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) { + lhs_contracting_size *= lhs_shape.dimensions(i); + } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(), + i)) { + batch_dim_sizes.push_back(lhs_shape.dimensions(i)); + } else { + lhs_non_contracting_dims.push_back(i); + lhs_non_contracting_size *= lhs_shape.dimensions(i); + } + } + // The canonical form of the lhs is + // [BatchDims, NonContractingDims, ContractingsDims] + std::vector lhs_transpose; + lhs_transpose.reserve(lhs_rank); + lhs_transpose.insert(lhs_transpose.end(), + original_dnums.lhs_batch_dimensions().begin(), + original_dnums.lhs_batch_dimensions().end()); + lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(), + lhs_non_contracting_dims.end()); + lhs_transpose.insert(lhs_transpose.end(), + original_dnums.lhs_contracting_dimensions().begin(), + original_dnums.lhs_contracting_dimensions().end()); + HloInstruction* transposed_lhs = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose), + lhs_shape), + original_dot->mutable_operand(0), 
lhs_transpose)); + std::vector lhs_reshape_dims = batch_dim_sizes; + lhs_reshape_dims.push_back(lhs_non_contracting_size); + lhs_reshape_dims.push_back(lhs_contracting_size); + // Reshape the contracting and non-contracting dimensions together. + HloInstruction* reshaped_lhs = + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims), + transposed_lhs)); + + const auto& rhs_shape = original_dot->operand(1)->shape(); + const int64 rhs_rank = rhs_shape.rank(); + const int64 num_rhs_non_contracting_dims = + rhs_rank - num_batch_dims - num_contracting_dims; + std::vector rhs_non_contracting_dims; + rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims); + int64 rhs_non_contracting_size = 1; + int64 rhs_contracting_size = 1; + for (int64 i = 0; i < rhs_rank; ++i) { + if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) { + rhs_contracting_size *= rhs_shape.dimensions(i); + } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(), + i)) { + rhs_non_contracting_dims.push_back(i); + rhs_non_contracting_size *= rhs_shape.dimensions(i); + } + } + + // The canonical form of the rhs is + // [BatchDims, ContractingsDims, NonContractingDims] + std::vector rhs_transpose; + rhs_transpose.reserve(rhs_rank); + rhs_transpose.insert(rhs_transpose.end(), + original_dnums.rhs_batch_dimensions().begin(), + original_dnums.rhs_batch_dimensions().end()); + rhs_transpose.insert(rhs_transpose.end(), + original_dnums.rhs_contracting_dimensions().begin(), + original_dnums.rhs_contracting_dimensions().end()); + rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(), + rhs_non_contracting_dims.end()); + HloInstruction* transposed_rhs = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose), + rhs_shape), + original_dot->mutable_operand(1), rhs_transpose)); + + std::vector rhs_reshape_dims = 
batch_dim_sizes; + rhs_reshape_dims.push_back(rhs_contracting_size); + rhs_reshape_dims.push_back(rhs_non_contracting_size); + // Reshape the contracting and non-contracting dimensions together. + HloInstruction* reshaped_rhs = + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims), + transposed_rhs)); + + std::vector dot_dims = batch_dim_sizes; + dot_dims.push_back(lhs_non_contracting_size); + dot_dims.push_back(rhs_non_contracting_size); + + DotDimensionNumbers dot_dnums; + for (int64 i = 0; i < num_batch_dims; ++i) { + dot_dnums.add_lhs_batch_dimensions(i); + dot_dnums.add_rhs_batch_dimensions(i); + } + dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1); + dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); + + HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims), + reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config())); + + return computation->ReplaceInstruction( + original_dot, computation->AddInstruction(HloInstruction::CreateReshape( + original_dot->shape(), dot))); +} + } // namespace StatusOr DotDecomposer::Run(HloModule* module) { XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); - // Gather all batch Dot operations. - std::vector batch_dots; + // Gather all Non-canonical Dot operations. + std::vector non_canonical_dots; for (auto* computation : module->MakeNonfusionComputations()) { for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kDot) { continue; } const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { - batch_dots.push_back(instruction); + // A dot it not canonical if there are more than one contracting + // dimension. 
+ if (dnums.lhs_contracting_dimensions_size() != 1) { + non_canonical_dots.push_back(instruction); + continue; + } + if (dnums.lhs_batch_dimensions().empty() && + dnums.lhs_contracting_dimensions().empty()) { + non_canonical_dots.push_back(instruction); + continue; + } + if (dnums.lhs_batch_dimensions().empty()) { + continue; + } + std::vector canonical_batch_dims( + dnums.lhs_batch_dimensions_size()); + absl::c_iota(canonical_batch_dims, 0); + if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) || + !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) { + non_canonical_dots.push_back(instruction); } } } - // Decompose each batch Dot in 'batch_dots'. bool changed = false; - for (auto* dot : batch_dots) { - TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + for (auto* dot : non_canonical_dots) { + TF_RETURN_IF_ERROR(CanonicalizeDot(dot)); changed = true; } + + if (decompose_batch_dot_) { + std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (!dnums.lhs_batch_dimensions().empty()) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + } XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 6d0472689bf48092ceef2e9792c1358687d707ec..de3b508064bfadd88396f050142e682de2294434 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" namespace xla { @@ -53,6 +54,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleDot(HloInstruction* hlo) override; + Status HandleTuple(HloInstruction* hlo) override; + Status HandleTranspose(HloInstruction* hlo) override; Status HandleReshape(HloInstruction* hlo) override; @@ -77,6 +80,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleElementwiseBinary(HloInstruction* hlo) override; + Status HandleWhile(HloInstruction* hlo) override; + private: using OperandDynamicDimensionFn = std::functionSetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + }); +} + Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, @@ -173,7 +188,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { // Find out the new dynamic dimension after reduce. 
int64 dimensions_not_reduced_count = 0; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + for (int i = 0; i < operand->shape().rank(); ++i) { if (dimension == i) { parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, dynamic_size); @@ -207,7 +222,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { result_dim_mapping[i] = current_result_dims++; } - for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) { + for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) { if (!absl::c_linear_search( dimension_numbers.lhs_contracting_dimensions(), i)) { if (operand_index == 0) { @@ -217,7 +232,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { } } - for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) { + for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) { if (!absl::c_linear_search( dimension_numbers.rhs_contracting_dimensions(), i) && !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), @@ -383,6 +398,120 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( }); } +Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { + // While loop is handled by passing dynamic size hlos as parameters into the + // hlo while loop. This is done by replacing the original while with a new + // one. + // + // Before: + // + // op1 = ... + // op2 = ... + // op1_x = ... // dynamic dimension size of op1 + // while = while(op1, op2) + // + // + // After: + // + // op1 = ... + // op2 = ... + // op1_x = ... // dynamic dimension size of op1 + // while = while(op1, op2, op1_x) + // + // In the above graph, op_x is the bound of the dynamic dimension size of op1 + // and is wired into the while loop as new parameter. + // + // TODO(b/119843103): Once we implement dynamic bounds in XLA backend, dynamic + // bound can be propagated through native xla values instead of relying on + // additional parameter. 
+ + // dynamic_size_to_operand_id_index_map keeps track of dynamic size operations + // to their operand ids in the new while loop. + absl::flat_hash_map + dynamic_size_to_operand_id_index_map; + + // operands_to_add collects dynamic sizes that need to be added to the while + // loop as parameters. Note that a dynamic size is ignored if it is already + // part of the parameter. i.e.: + // + // We don't do: + // + // op1 = ... + // op2 = ... + // op_x = ... // dynamic dimension size of both op1 and op2 + // while = while(op1, op2, op_x, op_x) // 4 parameters + // + // But we do: + // + // op1 = ... + // op2 = ... + // op_x = ... // dynamic dimension size of both op1 and op2 + // while = while(op1, op2, op_x) + // + // An alternative is to do this in a while loop CSE pass. + // + std::vector operands_to_add; + int64 operand_count = hlo->shape().tuple_shapes_size(); + TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + hlo, [&](HloInstruction*, ShapeIndex, int64, int64, + HloInstruction* dynamic_size) { + const HloInstruction* tuple_operand = hlo->operand(0); + for (int64 i = 0; i < tuple_operand->operand_count(); ++i) { + if (dynamic_size == tuple_operand->operand(i)) { + dynamic_size_to_operand_id_index_map[dynamic_size] = i; + return Status::OK(); + } + } + auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size); + if (iter == dynamic_size_to_operand_id_index_map.end()) { + operands_to_add.push_back(dynamic_size); + dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++; + } + return Status::OK(); + })); + + if (!operands_to_add.empty()) { + // Only replace the while loop if there are new parameters to add. + HloInstruction* old_tuple_operand = hlo->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + WhileUtil::MakeInstructionsLiveInResult result, + WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add)); + // WhileUtil creates a new while hlo and tuple. Update the dynamic size + // mapping for the newly created tuple. 
+ HloInstruction* new_tuple_operand = + result.new_while_instr->mutable_operand(0); + parent_->CopyMapping(/*from=*/old_tuple_operand, /*to=*/new_tuple_operand); + hlo = result.new_while_instr; + } + + // We have replaced the while loop, now set the dynamic dimensions for the + // newly created while loop so that the hlos that consumes the while loop can + // see the dynamic dimensions. Also sets the dynamic parameter binding for + // running inference in the while loop. + DynamicParameterBinding binding_for_while; + TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + DynamicParameterBinding::DynamicParameter dynamic_parameter{ + operand_index, + {dynamic_size_to_operand_id_index_map[dynamic_size]}}; + DynamicParameterBinding::DynamicDimension dynamic_dimension{ + operand_index, index, dimension}; + TF_RETURN_IF_ERROR( + binding_for_while.Bind(dynamic_parameter, dynamic_dimension)); + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + })); + + // Run inference in while body and condition. 
+ TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( + hlo->while_body(), binding_for_while, parent_)); + TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( + hlo->while_condition(), binding_for_while, parent_)); + + return Status::OK(); +} + Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { return param_bindings_.ForEachBinding( [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter, @@ -430,15 +559,43 @@ Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( return Status::OK(); } +void DynamicDimensionInference::CopyMapping(HloInstruction* from, + HloInstruction* to) { + auto iter = per_hlo_dynamic_dimensions_.find(from); + if (iter != per_hlo_dynamic_dimensions_.end()) { + for (auto& dynamic_dimension : iter->second) { + HloInstruction* dynamic_size = + GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim); + SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim, + dynamic_size); + } + } +} + /* static */ StatusOr DynamicDimensionInference::Run( HloModule* module) { - VLOG(0) << "Param Config " << module->dynamic_parameter_binding().ToString(); + VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString(); DynamicDimensionInference inference(module); TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); return inference; } +string DynamicDimensionInference::ToString() const { + std::vector pieces; + pieces.push_back("DynamicDimensionInference: "); + for (const auto& mapping : dynamic_mapping_) { + const DynamicDimension& dynamic_dimension = mapping.first; + pieces.push_back(absl::StrFormat( + " -- instruction %s at %s has dim %lld as dynamic" + " dimension, which is represented by instruction %s", + dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(), + dynamic_dimension.dim, mapping.second->ToString())); + } + return absl::StrJoin(pieces, "\n"); +} + 
DynamicDimensionInference::DynamicDimensionInference(HloModule* module) : module_(module) {} diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 164d15bf111a92e3da957f609b54ee0662ef18b1..d0f2998328f3028ccbd5b33690a514371a03b5a1 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -88,6 +88,11 @@ class DynamicDimensionInference { iter.first->second.emplace(DynamicDimension{inst, index, dim}); } + // Copies the internal mapping from instruction `from` to instruction `to`. + // This is useful when an instruction is replaced by the other during the + // inferencing process. + void CopyMapping(HloInstruction* from, HloInstruction* to); + // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in // module_. Status AnalyzeDynamicDimensions(); @@ -101,6 +106,8 @@ class DynamicDimensionInference { using DynamicMapping = absl::flat_hash_map; DynamicMapping dynamic_mapping_; + // A convenient mapping from an hlo to the set of dynamic dimensions that it + // holds. 
using PerHloDynamicDimensions = absl::flat_hash_map>; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index ea9ebed45d99797ce4f80376ec3d0b758da3ca17..597cdf27c3318b3cf8bd5bb5f9b3239cf23a4c73 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -62,6 +62,17 @@ class DynamicDimensionInferenceTest : public HloTestBase { return module_->AddEmbeddedComputation(embedded_builder.Build()); } + HloComputation* GetGe() { + auto embedded_builder = HloComputation::Builder("ge"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs)); + return module_->AddEmbeddedComputation(embedded_builder.Build()); + } + std::unique_ptr module_; std::unique_ptr inference_; const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); @@ -292,7 +303,8 @@ TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { Window window; auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - zx_shape, a_param, b_param, /*feature_group_count=*/1, window, dnums, + zx_shape, a_param, b_param, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, HloTestBase::DefaultPrecisionConfig(2))); module_->AddEntryComputation(builder.Build()); @@ -433,6 +445,96 @@ TEST_F(DynamicDimensionInferenceTest, BroadcastTest) { EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr); } +TEST_F(DynamicDimensionInferenceTest, WhileTest) { + // Test the ability to trace into while loops. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); + + // Body: + // + // Param + // | | + // GTE1 GTE2 + // | | + // ADD + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto gte_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, body_param, 0)); + auto gte_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, body_param, 1)); + auto add = body_builder.AddInstruction( + HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1)); + body_builder.AddInstruction(HloInstruction::CreateTuple({add, add})); + + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + // Entry: + // + // Param + // | + // While + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, tuple_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 0})); + + 
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + // Test that dynamic dimension inference does the right thing. A lambda is + // used here since we want to test twice by running inference again + // (idempotency). + auto test_dynamic_dimension = [&]() { + HloInstruction* while_hlo = nullptr; + // The while hlo has been replaced, find the new one. + for (HloInstruction* inst : module_->entry_computation()->instructions()) { + if (inst->opcode() == HloOpcode::kWhile) { + while_hlo = inst; + } + } + ASSERT_NE(while_hlo, nullptr); + // The original while shape has 2 parameters. With dynamic size passed in + // as an extra parameter, the tuple should have 3 elements. + EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3); + HloInstruction* add = nullptr; + for (HloInstruction* inst : while_hlo->while_body()->instructions()) { + if (inst->opcode() == HloOpcode::kAdd) { + add = inst; + } + } + EXPECT_NE(add, nullptr); + EXPECT_NE(inference_->GetDynamicSize(add, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param); + }; + + TF_ASSERT_OK(RunInference()); + test_dynamic_dimension(); + TF_ASSERT_OK(RunInference()); + test_dynamic_dimension(); +} + TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { // Test the ability to trace reduce window batch dimensions. auto builder = HloComputation::Builder(TestName()); @@ -486,7 +588,7 @@ TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { // Test the ability to trace select and scatter batch dimensions. auto builder = HloComputation::Builder(TestName()); auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); - auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + auto source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); Window window; // First dimension is unchanged. 
@@ -513,22 +615,26 @@ TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { /*parameter_number=*/0, input_shape, "A")); auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* source = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, source_shape, "B")); auto init = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); - auto* reduce_window = - builder.AddInstruction(HloInstruction::CreateReduceWindow( - output_shape, a_param, init, window, GetAdd())); + auto* sns = builder.AddInstruction(HloInstruction::CreateSelectAndScatter( + input_shape, a_param, GetGe(), window, source, init, GetAdd())); module_->AddEntryComputation(builder.Build()); TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 0})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{2, {}, 0})); TF_ASSERT_OK(RunInference()); - EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(sns, {}, 0), size_param); } } // namespace diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..e34adfd2d2bbb7214cfa2da28291b133538845e5 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +StatusOr DynamicIndexSplitter::Run(HloModule* module) { + bool changed = false; + + std::vector computations = + module->MakeNonfusionComputations(); + for (HloComputation* computation : computations) { + for (HloInstruction* dynamic_op : computation->MakeInstructionPostOrder()) { + switch (dynamic_op->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + break; + default: + continue; + } + auto parent = dynamic_op->parent(); + bool is_update = dynamic_op->opcode() == HloOpcode::kDynamicUpdateSlice; + int64 num_indices = dynamic_op->operand(0)->shape().rank(); + + if (num_indices == 0) { + // If the operand rank is 0, directly replace R0 DS/DUS with the + // operand (for DS) or update (for DUS). 
+ if (is_update) { + TF_CHECK_OK(parent->ReplaceInstruction( + dynamic_op, dynamic_op->mutable_operand(1))); + } else { + TF_CHECK_OK(parent->ReplaceInstruction( + dynamic_op, dynamic_op->mutable_operand(0))); + } + changed = true; + continue; + } + + int64 index_operand_number = Cast(dynamic_op) + ->first_index_operand_number(); + auto index_operand = dynamic_op->mutable_operand(index_operand_number); + if (ShapeUtil::IsScalar(index_operand->shape())) { + // This DS/DUS already uses scalar indices. + continue; + } + TF_RET_CHECK(index_operand->shape().rank() == 1); + auto index_element_type = index_operand->shape().element_type(); + std::vector index_array; + for (int64 dim = 0; dim < num_indices; ++dim) { + auto slice = parent->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(index_element_type, {1}), index_operand, {dim}, + {dim + 1}, {1})); + auto bitcast = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(index_element_type, {}), slice)); + index_array.push_back(bitcast); + } + auto new_dynamic_op = + is_update + ? HloInstruction::CreateDynamicUpdateSlice( + dynamic_op->shape(), dynamic_op->mutable_operand(0), + dynamic_op->mutable_operand(1), absl::MakeSpan(index_array)) + : HloInstruction::CreateDynamicSlice( + dynamic_op->shape(), dynamic_op->mutable_operand(0), + absl::MakeSpan(index_array), + dynamic_op->dynamic_slice_sizes()); + TF_CHECK_OK(parent->ReplaceWithNewInstruction(dynamic_op, + std::move(new_dynamic_op))); + changed = true; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter.h b/tensorflow/compiler/xla/service/dynamic_index_splitter.h new file mode 100644 index 0000000000000000000000000000000000000000..3c12e3a4af287ad2272a08ba54cd99c2cad9d451 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Convert R1 index operands to DynamicSlice and DynamicUpdateSlice ops into +// separate scalars. +class DynamicIndexSplitter : public HloModulePass { + public: + DynamicIndexSplitter() = default; + absl::string_view name() const override { return "dynamic-index-splitter"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..98029d1faff7d669730f6b66e38fcefece70f0eb --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +class DynamicIndexSplitterTest : public HloTestBase {}; + +TEST_F(DynamicIndexSplitterTest, DynamicSlice) { + const char* const kDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY entry (operand: s32[4,5,6], indices: s32[3]) -> s32[1,1,1] { + operand = s32[4,5,6] parameter(0) + indices = s32[3] parameter(1) + ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, indices), dynamic_slice_sizes={1,1,1} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Parameter(0), + op::Reshape(op::Slice(op::Parameter(1))), + 
op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))))); + + for (int i = 0; i < 3; ++i) { + const HloInstruction* slice = module->entry_computation() + ->root_instruction() + ->operand(i + 1) + ->operand(0); + EXPECT_EQ(slice->slice_starts(0), i); + EXPECT_EQ(slice->slice_limits(0), i + 1); + } +} + +TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) { + const char* const kDynamicUpdateSlice = R"( + HloModule DynamicUpdatedSlice_module + + ENTRY entry (operand: s32[4,5,6], indices: s32[3], update: s32[1,1,1]) -> s32[4,5,6] { + operand = s32[4,5,6] parameter(0) + indices = s32[3] parameter(1) + update = s32[1,1,1] parameter(2) + ROOT dynamic-update-slice = s32[4,5,6] dynamic-update-slice(operand, update, indices) + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kDynamicUpdateSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::DynamicUpdateSlice(op::Parameter(0), op::Parameter(2), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))))); + + for (int i = 0; i < 3; ++i) { + const HloInstruction* slice = module->entry_computation() + ->root_instruction() + ->operand(i + 2) + ->operand(0); + EXPECT_EQ(slice->slice_starts(0), i); + EXPECT_EQ(slice->slice_limits(0), i + 1); + } +} + +TEST_F(DynamicIndexSplitterTest, AlreadyScalar) { + const char* const kDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY entry (operand: s32[4,5,6], index.0: s32[], index.1: s32[], index.2: s32[]) -> s32[1,1,1] { + operand = s32[4,5,6] parameter(0) + index.0 = s32[] parameter(1) + index.1 = s32[] parameter(2) + index.2 = s32[] 
parameter(3) + ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, index.0, index.1, index.2), dynamic_slice_sizes={1,1,1} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_FALSE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Parameter(0), op::Parameter(1), + op::Parameter(2), op::Parameter(3))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc new file mode 100644 index 0000000000000000000000000000000000000000..4db280f817141bd52e3a5b9564600a618f81aeac --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/dynamic_padder.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// ChooseIdentityValue looks at the instruction and returns a identity value +// which, when padded, doesn't change the result of the instruction. +// +// nullopt is returned if padding doesn't need to be reset. +StatusOr ChooseIdentityValue(HloInstruction* inst) { + HloComputation* comp = inst->parent(); + // Padding on elementwise operation doesn't affect the result of the effective + // data. + if (inst->IsElementwise()) { + return nullptr; + } + + switch (inst->opcode()) { + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: { + // Because of the way we do reduce, we already require the `init` operand + // of hlo reduce instruction to be identity value. Here we reuse the + // operand. + return inst->mutable_operand(1); + } + + case HloOpcode::kConvolution: + case HloOpcode::kDot: { + // Use 0 as padding value for convolution and dot. 
+ PrimitiveType ptype = inst->shape().element_type(); + return comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(ptype))); + } + + case HloOpcode::kPad: { + return inst->mutable_operand(1); + } + case HloOpcode::kParameter: + case HloOpcode::kGetDimensionSize: + case HloOpcode::kReshape: + case HloOpcode::kTuple: + case HloOpcode::kAllReduce: + case HloOpcode::kBroadcast: + return nullptr; + default: + return UnimplementedStrCat("Unimplimented padding for instruction: ", + inst->ToString()); + } +} + +} // namespace + +StatusOr DynamicPadder::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "Pre DynamicPadder HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module)); + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* inst : computation->instructions()) { + for (int64 operand_num = 0; operand_num < inst->operand_count(); + ++operand_num) { + HloInstruction* operand = inst->mutable_operand(operand_num); + if (!operand->shape().IsArray()) { + continue; + } + for (int64 dim = 0; dim < operand->shape().rank(); ++dim) { + HloInstruction* dynamic_size = + dynamic_dimension_inference.GetDynamicSize(operand, {}, dim); + if (dynamic_size == nullptr) { + continue; + } + VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @" + << dim; + TF_ASSIGN_OR_RETURN(HloInstruction * identity_value, + ChooseIdentityValue(inst)); + if (identity_value == nullptr) { + continue; + } + + // For each dimension, first generates a mask representing the + // effective area of data and padded area of data using iota and + // dynamic_size. 
For example, given a dimension of 7 elements and 5 + // effective elements: + // + // iota = [0, 1, 2, 3, 4, 5, 6] + // broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5] + // mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f] + // + // Once the mask is generated, the input data is then padded using the + // mask and pad value. + // + const Shape mask_shape = + ShapeUtil::ChangeElementType(operand->shape(), xla::U32); + const Shape pred_shape = + ShapeUtil::ChangeElementType(operand->shape(), xla::PRED); + HloInstruction* iota = computation->AddInstruction( + HloInstruction::CreateIota(mask_shape, dim)); + + HloInstruction* broadcasted_effective_size = + computation->AddInstruction(HloInstruction::CreateBroadcast( + mask_shape, dynamic_size, {})); + HloInstruction* pred = computation->AddInstruction( + HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota, + broadcasted_effective_size)); + + HloInstruction* broadcasted_identity_value = + computation->AddInstruction(HloInstruction::CreateBroadcast( + operand->shape(), identity_value, {})); + HloInstruction* padded = + computation->AddInstruction(HloInstruction::CreateTernary( + operand->shape(), HloOpcode::kSelect, pred, operand, + broadcasted_identity_value)); + TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded)); + operand = inst->mutable_operand(operand_num); + changed = true; + } + } + } + } + HloDCE dce; + TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); + VLOG(2) << "Post DynamicPadder HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.h b/tensorflow/compiler/xla/service/dynamic_padder.h new file mode 100644 index 0000000000000000000000000000000000000000..509269f7f56746fa5516ad917a04221587c6dcca --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// With bounded shapes, only part of the shape contains effective data and the +// rest contains padded data, whose value can be anything depending on the +// source of the data. When a bounded shape is directly consumed by an +// instruction that collapses dimensions (reduce for example), the padding data +// would affect result of the instruction. +// +// DynamicPadder uses DynamicDimensionInference to detect bounded shapes in a +// hlo module, it then inserts certain instructions to reset the padding into an +// identity value so that in doesn't affect the result of subsequent +// instruction. For example, it'd reset the padding to 0 before a bounded shape +// is consumed by a reduce-sum. 
+class DynamicPadder : public HloModulePass { + public: + absl::string_view name() const override { return "dynamic_padder"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..55a11286e4596d87c330315322cae704fc5cd707 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_padder.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DynamicPadderTest : public HloTestBase { + protected: + DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } + + StatusOr RunPadder() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before padder"); + + DynamicPadder padder; + + return padder.Run(module_.get()); + } + + void ExpectPadded(const HloInstruction* inst) { + EXPECT_THAT(inst, + op::Select(op::Lt(op::Iota(), op::Broadcast(op::Parameter())), + ::testing::_, op::Broadcast())); + } + + HloComputation* GetScalarAddComputation() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return 
module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + std::unique_ptr module_; + const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {}); +}; + +TEST_F(DynamicPadderTest, ReduceTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetScalarAddComputation())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(reduce->operand(0)); +} + +TEST_F(DynamicPadderTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(conv->operand(0)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc index c8bfc8905064bcd7b68fe259fbcc1546ff083dbd..7f0ae692f7414dbdcccda8b287c9059bcf920df1 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -29,7 +29,8 @@ Status DynamicParameterBinding::Bind( } absl::optional -DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) { +DynamicParameterBinding::GetBinding( + const DynamicDimension& dynamic_dimension) const { auto param_iter = bindings_.find(dynamic_dimension); if (param_iter == bindings_.end()) { return absl::nullopt; @@ -70,7 +71,7 @@ StatusOr DynamicParameterBinding::CreateFromProto( int64 target_param_num = binding.target_param_num(); ShapeIndex target_param_index(binding.target_param_index().begin(), binding.target_param_index().end()); - int64 target_dim_num = binding.target_param_num(); + int64 target_dim_num = binding.target_param_dim_num(); TF_RETURN_IF_ERROR( result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index}, @@ -111,7 +112,8 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const { return ForEachBinding([&](const DynamicParameter& dynamic_parameter, const DynamicDimension& dynamic_dimension) -> Status { - TF_RET_CHECK(dynamic_parameter.parameter_num < entry->num_parameters()); + TF_RET_CHECK(dynamic_parameter.parameter_num >= 0 && + dynamic_parameter.parameter_num < entry->num_parameters()); TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters()); TF_RET_CHECK(ShapeUtil::IndexIsValid( 
entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(), @@ -121,10 +123,11 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const { dynamic_dimension.parameter_index)); TF_RET_CHECK( dynamic_dimension.dimension < - ShapeUtil::Rank(ShapeUtil::GetSubshape( + ShapeUtil::GetSubshape( entry->parameter_instruction(dynamic_dimension.parameter_num) ->shape(), - dynamic_dimension.parameter_index))); + dynamic_dimension.parameter_index) + .rank()); return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h index dd474d8eed1b2c30ddb8f624a864198c74eacaba..57af2c43d3c65f7340e6a9f04e5abbf052ebceea 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h @@ -89,7 +89,7 @@ class DynamicParameterBinding { // // Returns nullopt if the binding is not set. absl::optional GetBinding( - const DynamicDimension& dynamic_dimension); + const DynamicDimension& dynamic_dimension) const; using BindingFn = std::functionToProto(); + TF_ASSERT_OK_AND_ASSIGN(*binding, + DynamicParameterBinding::CreateFromProto(proto)); + } +}; TEST_F(DynamicParameterBindingTest, SimpleBinding) { // 'b' is a dynamic shape; 'a' represents the real size of b's first @@ -56,15 +64,20 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}}, DynamicParameterBinding::DynamicDimension{1, {}, 0})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, - /*parameter_index=*/{}, - /*dimension=*/0}); - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({})); - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, + 
/*parameter_index=*/{}, + /*dimension=*/0}); + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + test(binding); + SerializeAndDeserialize(&binding); + test(binding); } TEST_F(DynamicParameterBindingTest, TupleBinding) { @@ -89,16 +102,21 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, DynamicParameterBinding::DynamicDimension{0, {1}, 0})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({0})); - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + test(binding); + SerializeAndDeserialize(&binding); + test(binding); } TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) { @@ -127,26 +145,35 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, DynamicParameterBinding::DynamicDimension{0, {1}, 1})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({0})); - - absl::optional param2 = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - EXPECT_TRUE(param2); - EXPECT_EQ(param2->parameter_num, 0); - 
EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); - - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + + absl::optional param2 = + + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + EXPECT_TRUE(param2); + EXPECT_EQ(param2->parameter_num, 0); + EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + + test(binding); + + SerializeAndDeserialize(&binding); + + // Test the binding again after deserialization. + test(binding); } } // namespace diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 6f1f95f2e9082649b6ca9cc0da5c238e15b77c10..a62a743802456d0239438a12884f5a594aa05798 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -423,6 +423,10 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitSin(op->shape().element_type(), operand_value); case HloOpcode::kTanh: return EmitTanh(op->shape().element_type(), operand_value); + case HloOpcode::kSqrt: + return EmitSqrt(op->shape().element_type(), operand_value); + case HloOpcode::kRsqrt: + return EmitRsqrt(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {operand_value}, @@ -440,14 +444,16 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( {operand_value}, {operand_value->getType()}, b_); case HloOpcode::kSign: { - // TODO(b/32151903): Ensure consistent sign behavior for -0.0. 
auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = FCmpOEQ(operand_value, zero); - auto olt = FCmpOLT(operand_value, zero); - return Select(oeq, zero, - Select(olt, llvm::ConstantFP::get(type, -1.0), - llvm::ConstantFP::get(type, 1.0))); + auto ne0_i1 = FCmpONE(operand_value, zero); + auto ne0_float = UIToFP(ne0_i1, type); + llvm::Value* result = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::copysign, {ne0_float, operand_value}, + {operand_value->getType()}, b_); + auto is_nan = FCmpUNO(operand_value, operand_value); + result = Select(is_nan, operand_value, result); + return result; } case HloOpcode::kIsFinite: { // abs(x) o!= inf, this works because the comparison returns false if @@ -653,6 +659,20 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs), FDiv(EmitExtractImag(operand_value), cplx_abs))); } + case HloOpcode::kSqrt: { + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + auto c = llvm::ConstantFP::get(a->getType(), 0.5); + auto d = llvm::ConstantFP::get(b->getType(), 0.0); + return EmitComplexPower(op, a, b, c, d); + } + case HloOpcode::kRsqrt: { + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + auto c = llvm::ConstantFP::get(a->getType(), -0.5); + auto d = llvm::ConstantFP::get(b->getType(), 0.0); + return EmitComplexPower(op, a, b, c, d); + } case HloOpcode::kNegate: return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), FNeg(EmitExtractImag(operand_value))); @@ -736,6 +756,43 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } } +// (a+bi)^(c+di) = +// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), +// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) +StatusOr ElementalIrEmitter::EmitComplexPower( + const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c, + llvm::Value* d) { + PrimitiveType component_type = + 
primitive_util::ComplexComponentType(op->shape().element_type()); + auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); + auto zero = llvm::ConstantFP::get(a->getType(), 0); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto one = llvm::ConstantFP::get(a->getType(), 1); + auto half_c = FMul(one_half, c); + + TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, + EmitPow(component_type, aa_p_bb, half_c)); + + auto neg_d = FNeg(d); + TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); + auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, + EmitExp(component_type, neg_d_arg_lhs)); + auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); + auto half_d = FMul(one_half, d); + auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); + TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); + // 0^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see + // Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. 
+ return Select( + And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)), + EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero), + EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))); +} + StatusOr ElementalIrEmitter::EmitComplexBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { @@ -802,33 +859,11 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( EmitExtractImag(rhs_value), b_)); case HloOpcode::kPower: { - // (a+bi)^(c+di) = - // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), - // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) - PrimitiveType component_type = - primitive_util::ComplexComponentType(op->shape().element_type()); auto a = EmitExtractReal(lhs_value); auto b = EmitExtractImag(lhs_value); auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); - auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = FMul(one_half, c); - - TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, - EmitPow(component_type, aa_p_bb, half_c)); - auto neg_d = FNeg(d); - TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); - TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, - EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); - TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = FMul(one_half, d); - auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); - TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); - TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)); + return EmitComplexPower(op, a, b, c, d); } default: return Unimplemented("binary complex op '%s'", @@ -846,6 +881,9 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, return 
llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } +// TODO(b/123355973): We have an implementation of erfinv in math.cc. We +// shouldn't have two implementations, especially since this one isn't testable +// (it's only observable via a normally-distributed RNG). StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Value* x) { if (prim_type != F16 && prim_type != F32 && prim_type != F64) { @@ -1038,6 +1076,18 @@ StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, return Select(x_is_small, for_small_x, for_large_x); } +StatusOr ElementalIrEmitter::EmitSqrt(PrimitiveType prim_type, + llvm::Value* value) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value}, + {value->getType()}, b_); +} + +StatusOr ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value) { + TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value)); + return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt); +} + StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, @@ -1327,9 +1377,9 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If implicit broadcast is needed, the source dimensions that are broadcast // have index 0. - CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); + CHECK_EQ(operand_shape.rank(), hlo.shape().rank()); llvm_ir::IrArray::Index source_index(target_index.GetType()); - for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { + for (int64 i = 0; i < hlo.shape().rank(); ++i) { if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { source_index.push_back(target_index[i]); } else { @@ -1353,26 +1403,69 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); llvm::Type* raw_value_ty = raw_value->getType(); - // Convert raw integer to float in range [0, 1) if the element is a float. 
+ // If we're generating a floating-point value, convert the raw integer R (i.e. + // `raw_value`) to a float in the range [0, 1). + // + // The basic approach is to choose a significand and exponent such that the + // significand is uniformly distributed and the exponent is distributed, well, + // exponentially (it's more likely to be close to 0 than far from 0). + // + // An easy way to do this is to say that the significand is the first S bits + // of R, and the exponent is determined by the number of trailing zeroes in R, + // exp = 2^-(cttz(R) + 1). (+1 because the largest exponent should be -1; + // this way the largest value we can return is 1.999... * 2^-1 = 1-ε.) + // + // This results in a small bias. Namely, if R has enough trailing zeroes, the + // significand and exponent will "overlap". As a concrete example, consider + // + // 20 X's 12 zeroes + // R = 0bXXXXXXXXXXXXXXXXXXXX000000000000 + // + // Here the exponent is 2^-13 because R has 12 trailing zeroes. The + // significand is made up of the first 23 most-significant bits of R, which we + // observe contain 3 zeroes. This is biased because any random value with + // exponent 2^-12 will have a significand which ends in `000`. + // + // For f32s, this problem occurs only when there are more than 32-23 = 9 + // trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the + // probability of a large bias (i.e. many trailing 0s in the significand) is + // exponentially low. So we deem this acceptable. llvm::Value* elem_value = raw_value; if (elem_ir_ty->isFloatingPointTy()) { - unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); - CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); - // Perform the division using the float type with the same number of bits - // as the raw value to avoid overflow. 
- if (raw_value_size_in_bits == 32) { - elem_value = UIToFP(elem_value, b_->getFloatTy()); - elem_value = FDiv(elem_value, - llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); - } else { - elem_value = UIToFP(elem_value, b_->getDoubleTy()); - elem_value = FDiv( - elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); - } - - if (elem_ir_ty != elem_value->getType()) { - elem_value = FPTrunc(elem_value, elem_ir_ty); - } + const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics(); + const int bits = raw_value_ty->getPrimitiveSizeInBits(); + CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics)); + + // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the + // implicit "1." at the beginning of the significand. + const int significand_bits = + llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1; + + llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()}, + {raw_value->getType()}, b_); + llvm::Value* significand = LShr(raw_value, bits - significand_bits); + + // Exponent bias is -127 for f32, meaning that if the exponent is E and the + // significand is S, then the value of the number is 2^(E - 127) * (1.S). + // + // We want cttz == 0 to correspond to 2^-1, so our exponent is computed as + // E = 126 - cttz. + // + // For f64, this is all the same, except the bias is -1023. + // + // In IEEE floating point, the absolute value of the exponent bias equals + // the value of the largest possible exponent. + const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics); + llvm::Value* exponent = + Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz); + + // Now just slot everything into place! The `Trunc` is here because + // raw_value may be larger than our float destination. 
+ elem_value = + BitCast(Trunc(Or(Shl(exponent, significand_bits), significand), + b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())), + elem_ir_ty); } // Convert the value for the requested distribution. @@ -1750,7 +1843,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + const int64 rank = input_hlo->shape().rank(); // Use the same index type for all tensor accesses in the same kernel. llvm::Type* index_type = index.GetType(); llvm_ir::IrArray::Index slice_start_index(index_type, rank); @@ -1758,9 +1851,10 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1 + i))(zero_index)); // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) @@ -1893,7 +1987,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. 
- const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + const int64 rank = input_hlo->shape().rank(); llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which @@ -1905,9 +1999,11 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(start_hlo)(dim_index)); + + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(2 + i))(zero_index)); // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) @@ -2128,8 +2224,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kReal: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { @@ -2225,7 +2323,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( auto* iota = Cast(hlo); PrimitiveType element_type = iota->shape().element_type(); IrArray::Index elem_index = - ShapeUtil::Rank(iota->shape()) > 1 + iota->shape().rank() > 1 ? 
target_index.SourceIndexOfBroadcast( iota->shape(), ShapeUtil::MakeShapeWithDescendingLayout( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index d3e2acaabd4f602171def70ccd3d4fd5adce0d0d..819465f1e5d633a0652b09005a3d9a08874759bd 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -119,6 +119,12 @@ class ElementalIrEmitter : public IrBuilderMixin { virtual StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value); + virtual StatusOr EmitSqrt(PrimitiveType prim_type, + llvm::Value* value); + + virtual StatusOr EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value); + virtual StatusOr EmitLog1p(PrimitiveType prim_type, llvm::Value* value); @@ -211,13 +217,21 @@ class ElementalIrEmitter : public IrBuilderMixin { const HloModuleConfig& hlo_module_config_; private: + // Computes the complex power function, returns (a + i*b)^(c + i*d). + StatusOr EmitComplexPower(const HloInstruction* op, + llvm::Value* a, llvm::Value* b, + llvm::Value* c, llvm::Value* d); + // Returns a ElementGenerator for an RNG HloInstruction using the Philox // random number generation algorithm. llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator); + // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. + // + // Precondition: raw_value has at least as many bits as hlo's element type. 
StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 10b8c01ff1383658fcfb2271c177ba54347f985a..1518d83083b3b0ce876da9344c483a23cd5b073c 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" - namespace xla { StatusOr> Executable::ExecuteOnStreams( @@ -173,11 +172,13 @@ Status Executable::DumpHloSnapshot() { } filename = SanitizeFileName(std::move(filename)); string file_path = tensorflow::io::JoinPath(directory_path, filename); - string result; - TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(hlo_session, &result)); - return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, - result); + const size_t size = hlo_session.ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK(tensorflow::SerializeToBufferDeterministic( + hlo_session, serialized.get(), size)); + return tensorflow::WriteStringToFile( + tensorflow::Env::Default(), file_path, + absl::string_view(serialized.get(), size)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 01cef499665c050d4453382289168276028e1d26..a58ac39dffad56315308f784b08e6b6087b8e30a 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -153,10 +153,9 @@ static StatusOr> GatherLoopBody( dim_numbers.index_vector_dim() == gather.operand(1)->shape().dimensions_size()); - TF_ASSIGN_OR_RETURN( - HloInstruction * induction_var_as_vector, + HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, 
/*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + /*result_shape_bounds=*/{1}); HloInstruction* index_vector; @@ -222,7 +221,7 @@ static StatusOr> GatherLoopBody( {operand, start_indices, updated_accumulator}}; } -static StatusOr CreateGatherLoopAccumulatorInitValue( +static HloInstruction* CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, absl::Span slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { @@ -297,7 +296,7 @@ static StatusOr PermuteBatchAndOffsetDims( // [3,1] out of operand into an accumulator of shape [4,3,1]. We then // reshape this result to [2,2,3] and finally transpose it to [2,3,2]. -StatusOr GatherExpander::ExpandGather( +StatusOr GatherExpander::ExpandInstruction( HloInstruction* gather_instr) { CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape())); @@ -332,12 +331,10 @@ StatusOr GatherExpander::ExpandGather( CHECK_EQ(gather_loop_trip_count, canonical_start_indices->shape().dimensions(0)); - TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_init, - CreateGatherLoopAccumulatorInitValue( - computation, output_shape.element_type(), - gather_instr->gather_slice_sizes(), gather_loop_trip_count, - gather_instr->gather_dimension_numbers())); + HloInstruction* accumulator_init = CreateGatherLoopAccumulatorInitValue( + computation, output_shape.element_type(), + gather_instr->gather_slice_sizes(), gather_loop_trip_count, + gather_instr->gather_dimension_numbers()); StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( @@ -364,25 +361,11 @@ StatusOr GatherExpander::ExpandGather( output_rank); } -StatusOr GatherExpander::Run(HloModule* module) { - auto is_nontrivial_gather = [](HloInstruction* inst) { - return inst->opcode() == HloOpcode::kGather && - // Avoid expanding gather ops that produce zero sized tensors, - // instead punt these to ZeroSizedHloElimination. 
- !ShapeUtil::IsZeroElementArray(inst->shape()); - }; - - std::vector gather_instrs; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - absl::c_copy_if(computation->instructions(), - std::back_inserter(gather_instrs), is_nontrivial_gather); - } - - for (HloInstruction* inst : gather_instrs) { - TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandGather(inst)); - TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); - } - - return !gather_instrs.empty(); +bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) { + return inst->opcode() == HloOpcode::kGather && + // Avoid expanding gather ops that produce zero sized tensors, + // instead punt these to ZeroSizedHloElimination. + !ShapeUtil::IsZeroElementArray(inst->shape()); } + } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 8af9c6b71fbc391bf7c0e9809e979b65135a6df3..5625a37cb46ca5b70f69d86bc424f6512bfb293f 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -16,20 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic // slices. This lets backends that don't support gather directly to // nevertheless have a minimum level of support. 
-class GatherExpander : public HloModulePass { +class GatherExpander : public OpExpanderPass { public: absl::string_view name() const override { return "gather_expander"; } - StatusOr Run(HloModule* module) override; protected: - StatusOr ExpandGather(HloInstruction* gather_instr); + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* gather_inst) override; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index a3102368cb1dba15da7422337666d278cef775ab..e1ea5c39d58b6d23b076740626ca0ad63dc341ee 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -89,7 +89,7 @@ ENTRY main { // an implementation detail from WhileUtil::MakeCountedLoop). const Shape& while_shape = while_instr->shape(); - ASSERT_TRUE(ShapeUtil::IsTuple(while_shape)); + ASSERT_TRUE(while_shape.IsTuple()); ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4); EXPECT_TRUE(ShapeUtil::SameDimensions( diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index bec02e14f951c6d905b7329be5c02896984279d0..cb43c27be961262bf29d4a3958de62cfada19aed 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -83,7 +82,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), @@ -120,7 +119,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_buffer.on_host_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); - if (ShapeUtil::IsArray(device_subshape)) { + if (device_subshape.IsArray()) { TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. 
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index bfd1b6cb1492f5cb709e2ecefe73782094e26f5e..25c4f70d89b4ebc483a61f1e28c7a55eb31f4bdf 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3,6 +3,11 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_test") licenses(["notice"]) # Apache 2.0 @@ -24,12 +29,6 @@ filegroup( ]), ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - xla_proto_library( name = "backend_configs", srcs = ["backend_configs.proto"], @@ -94,8 +93,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -135,6 +134,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -263,7 +264,9 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -302,6 +305,7 @@ cc_library( 
"sequential_thunk.cc", "thunk.cc", "thunk_schedule.cc", + "triangular_solve_thunk.cc", "tuple_thunk.cc", "while_thunk.cc", ], @@ -321,6 +325,7 @@ cc_library( "sequential_thunk.h", "thunk.h", "thunk_schedule.h", + "triangular_solve_thunk.h", "tuple_thunk.h", "while_thunk.h", ], @@ -361,7 +366,10 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "//tensorflow/stream_executor:blas", + "//tensorflow/stream_executor:device_memory", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -392,18 +400,21 @@ cc_library( srcs = ["cudnn_conv_algorithm_picker.cc"], hdrs = ["cudnn_conv_algorithm_picker.h"], deps = [ + ":autotuning_proto", ":backend_configs", ":buffer_comparator", ":cudnn_conv_runner", ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "//tensorflow/core:logger", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -551,6 +562,44 @@ cc_library( ], ) +cc_library( + name = "gpu_sanitize_constant_names", + srcs = ["gpu_sanitize_constant_names.cc"], + hdrs = ["gpu_sanitize_constant_names.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = 
"gpu_sanitize_constant_names_test", + srcs = ["gpu_sanitize_constant_names_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_sanitize_constant_names", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "fusion_merger", srcs = ["fusion_merger.cc"], @@ -675,6 +724,7 @@ cc_library( ":gpu_hlo_schedule", ":gpu_hlo_support_checker", ":gpu_layout_assignment", + ":gpu_sanitize_constant_names", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", @@ -694,6 +744,9 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", + "//tensorflow/compiler/xla/service:convolution_group_converter", + "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -711,6 +764,8 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:sort_simplifier", + 
"//tensorflow/compiler/xla/service:stable_sort_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -724,6 +779,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/cuda:cuda_diagnostics", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1004,14 +1060,10 @@ cc_library( srcs = ["variadic_op_splitter.cc"], hdrs = ["variadic_op_splitter.h"], deps = [ - ":ir_emission_utils", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "@com_google_absl//absl/strings", @@ -1037,3 +1089,12 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) + +xla_proto_library( + name = "autotuning_proto", + srcs = ["autotuning.proto"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/autotuning.proto b/tensorflow/compiler/xla/service/gpu/autotuning.proto new file mode 100644 index 0000000000000000000000000000000000000000..b4a08963b4f2ebc55c89ed57325093536f343bd1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/autotuning.proto @@ -0,0 +1,81 @@ +// This file defines protos that store the results of autotuning XLA:GPU +// operations. +// +// They are in proto format because we want to log them structured. They offer +// tremendous statistical, testing, and debugging value. 
+syntax = "proto3"; + +package xla.gpu; + +import "google/protobuf/duration.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; + +message CudnnVersion { + int32 major = 1; + int32 minor = 2; + int32 patch = 3; +} + +message ComputeCapability { + int32 major = 1; + int32 minor = 2; +} + +message AutotuneResult { + message SuccessResult { + int64 scratch_bytes = 1; + google.protobuf.Duration run_time = 2; + } + + message ConvKey { + int64 algorithm = 1; + bool tensor_ops_enabled = 2; + } + + // If the conv runs successfully, success will be populated with the + // autotuning result. Otherwise, the error message is propagated. + oneof result { + SuccessResult success = 3; + string error_string = 4; + } + + oneof key { + ConvKey conv = 5; + } + + // Sometimes we run a correctness checker during autotuning. It compares the + // result buffer content between two algorithms, say, "reference" and "test" + // algorithms. The "test" algorithm is the one associated with this + // AutotuneResult. + // + // This field records the reference algorithm used. Notice that naming it + // "reference" doesn't mean it's always correct. However, empirically it's + // more correct, as it's "algo 0", less fancy than the compared one. + // + // Notice that the checker_failure may exist even in the success case. + // This is because the error string in `result` comes from the underlying + // implementation like cuDNN, which isn't aware that it produced an incorrect + // result. And even if the checker detects an incorrect result, we can still + // retrieve scratch_bytes and runtime_ms. + oneof checker_failure { + ConvKey reference_conv = 6; + } +} + +message AutotuneLog { + message Instruction { + xla.HloInstructionProto instruction = 1; + repeated xla.ShapeProto operand_shapes = 2; + } + + oneof instr_oneof { + Instruction instr = 1; + } + + // Records all auto-tuning results per algorithm. 
+ repeated AutotuneResult results = 3; + + CudnnVersion cudnn_version = 4; + ComputeCapability compute_capability = 5; +} diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 528209abc75777440163c2e1512658b8ad36315b..eb59ee5a1d47b6b706ef3f53a76069b3538eb6b7 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -57,16 +58,16 @@ StatusOr> BufferAllocations::Builder::Build( // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. 
- if (registered_buffers_.count(i)) { - se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % expected_alignment != + if (se::DeviceMemoryBase* address = + tensorflow::gtl::FindOrNull(registered_buffers_, i)) { + if (reinterpret_cast(address->opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %d must be a multiple of %x, but " "was %p", - i, kEntryParameterAlignBytes, address.opaque()); + i, kEntryParameterAlignBytes, address->opaque()); } - buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); + buffer_allocations->SetBuffer(i, *address); continue; } diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 14186b8faa68ad8492ea4863fcd7bd746e2eae48..9413ac2cff7c8d3ec4be6662569c580060bf1173 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -52,7 +53,8 @@ class BufferAllocations { DeviceMemoryAllocator* memory_allocator); private: - std::map registered_buffers_; + absl::flat_hash_map + registered_buffers_; }; ~BufferAllocations(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index 60289506524759580dbb9b82147c78c4ce1cb25e..2cceb0422d08ff7951308b0727941f5437785447 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -188,13 +188,8 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { computation_->AddInstruction(HloInstruction::CreateBroadcast( batch_norm->operand(3)->shape(), epsilon, {})))); HloInstruction* inverse_stddev = - computation_->AddInstruction(HloInstruction::CreateBinary( - var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon, - computation_->AddInstruction(HloInstruction::CreateBroadcast( - var_plus_epsilon->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(-.5))), - {})))); + computation_->AddInstruction(HloInstruction::CreateUnary( + var_plus_epsilon->shape(), HloOpcode::kRsqrt, var_plus_epsilon)); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 6d6780fa1c7b0c636eb771c40e74f074cd8c4c4b..603af5a654589e0b02c762b57d70a8b7628b1d0f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -16,14 +16,17 @@ limitations under the 
License. #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -32,7 +35,6 @@ namespace { using absl::optional; using se::DeviceMemoryBase; -using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; class ScratchAllocator : public se::ScratchAllocator { @@ -132,6 +134,31 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { return tensorflow::mutex_lock{it->second}; } +xla::gpu::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { + xla::gpu::CudnnVersion cudnn_version; + if (auto* dnn = stream_executor->AsDnn()) { + StatusOr version_or = dnn->GetVersion(); + if (version_or.ok()) { + const auto& version = version_or.ValueOrDie(); + cudnn_version.set_major(version.major_version()); + cudnn_version.set_minor(version.minor_version()); + cudnn_version.set_patch(version.patch()); + } + } + return cudnn_version; +} + +xla::gpu::ComputeCapability GetComputeCapability( + se::StreamExecutor* stream_executor) { + xla::gpu::ComputeCapability cc; + int cc_major, cc_minor; + stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + cc.set_major(cc_major); + cc.set_minor(cc_minor); + return cc; +} + } // anonymous namespace // We could have caching here so that we don't 
redo this work for two identical @@ -145,8 +172,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -StatusOr -CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { +StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( + const HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. const bool cross_check_enabled = @@ -232,8 +259,6 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); initialize_buffer(result_buffer); - se::dnn::ProfileResult best_result; - int64 best_result_bytes_used = 0; TF_ASSIGN_OR_RETURN(auto backend_config, instr->backend_config()); @@ -243,82 +268,119 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { // this algorithm considered correct, though. optional first_algorithm; TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + std::vector profile_results; for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - backend_config.set_algorithm(alg.algo_id()); - backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); - TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); - bool launch_ok = + // Use assignment instead of brace-list to make GCC 4.9 happy. 
+ RunConvOptions options; + options.profile_result = &profile_result; + options.algo_override = alg; + Status launch_status = RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, - &scratch_allocator, &stream, &profile_result) - .ok(); - - if (launch_ok && profile_result.is_valid()) { - const bool crash_on_checking_failure = - instr->GetModule() - ->config() - .debug_options() - .xla_gpu_crash_on_verification_failures(); - if (comparator.has_value()) { - StatusOr result = comparator->CompareEqual( - se::DeviceMemory(result_buffer)); - if (!result.ok()) { - LOG(ERROR) << "Unable to compare " - << AlgorithmToString(*first_algorithm) << " against " - << AlgorithmToString(alg) << " for " << instr->ToString() - << ": " << result.status(); - CHECK(!crash_on_checking_failure); - } else if (!result.ValueOrDie()) { - LOG(ERROR) << "Results mismatch between different convolution " - "algorithms. This is likely a bug in convolution, or " - "an excessive loss of precision in convolution. 
" - << instr->ToString() << " for " - << AlgorithmToString(*first_algorithm) << " vs " - << AlgorithmToString(alg); - CHECK(!crash_on_checking_failure); - } - } else if (cross_check_enabled) { - auto comp = F16BufferComparator::Create( - se::DeviceMemory(result_buffer), compiler_, allocator, - &stream); - if (comp.ok()) { - comparator.emplace(comp.ConsumeValueOrDie()); - first_algorithm.emplace(alg); - } else { - LOG(ERROR) << "Fail to initialize buffer comparator: " - << comp.status() << ", instruction: " << instr->ToString(); - CHECK(!crash_on_checking_failure); - } + &scratch_allocator, &stream, options); + + profile_results.emplace_back(); + AutotuneResult& result = profile_results.back(); + result.mutable_conv()->set_algorithm(alg.algo_id()); + result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled()); + + if (!launch_status.ok()) { + result.set_error_string(launch_status.error_message()); + continue; + } + + if (!profile_result.is_valid()) { + result.set_error_string("Invalid profile result"); + continue; + } + + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); + result.mutable_success()->set_scratch_bytes(scratch_bytes_used); + *result.mutable_success()->mutable_run_time() = + protobuf_util::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); + + const bool crash_on_checking_failure = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_crash_on_verification_failures(); + + if (comparator.has_value()) { + StatusOr compare_result = comparator->CompareEqual( + se::DeviceMemory(result_buffer)); + if (!compare_result.ok()) { + LOG(ERROR) << "Unable to compare " + << AlgorithmToString(*first_algorithm) << " against " + << AlgorithmToString(alg) << " for " << instr->ToString() + << ": " << compare_result.status(); + CHECK(!crash_on_checking_failure); + } else if (!compare_result.ValueOrDie()) { + LOG(ERROR) << "Results mismatch between different convolution " + "algorithms. 
This is likely a bug in convolution, or " + "an excessive loss of precision in convolution. " + << instr->ToString() << " for " + << AlgorithmToString(*first_algorithm) << " vs " + << AlgorithmToString(alg); + CHECK(!crash_on_checking_failure); + auto* failure = result.mutable_reference_conv(); + failure->set_algorithm(first_algorithm->algo_id()); + failure->set_tensor_ops_enabled(first_algorithm->tensor_ops_enabled()); } - int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); - VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) - << " succeeded, taking " << profile_result.elapsed_time_in_ms() - << "ms and using " << NumBytesToString(scratch_bytes_used) - << " of scratch (Best result: " - << best_result.elapsed_time_in_ms() << "ms, " - << NumBytesToString(best_result_bytes_used) << " of scratch)"; - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - best_result_bytes_used = scratch_bytes_used; + } else if (cross_check_enabled) { + auto comp = F16BufferComparator::Create( + se::DeviceMemory(result_buffer), compiler_, allocator, + &stream); + if (comp.ok()) { + comparator.emplace(comp.ConsumeValueOrDie()); + first_algorithm.emplace(alg); + } else { + LOG(ERROR) << "Fail to initialize buffer comparator: " << comp.status() + << ", instruction: " << instr->ToString(); + CHECK(!crash_on_checking_failure); } - } else { - VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; } } - if (best_result.is_valid()) { - VLOG(2) << "Best algorithm for " << instr->ToString() << ": " - << AlgorithmToString(best_result.algorithm()) << ", takes " - << best_result.elapsed_time_in_ms() << "ms, and uses " - << best_result_bytes_used << "B of scratch memory."; - return AutotuneResult{best_result.algorithm().algo_id(), - best_result.algorithm().tensor_ops_enabled(), - best_result_bytes_used, - absl::Milliseconds(best_result.elapsed_time_in_ms())}; + + // Log the autotuning result. 
+ { + AutotuneLog log; + *log.mutable_instr()->mutable_instruction() = instr->ToProto(); + for (const auto* op : instr->operands()) { + *log.mutable_instr()->add_operand_shapes() = op->shape().ToProto(); + } + for (const auto& profile : profile_results) { + *log.add_results() = profile; + } + *log.mutable_compute_capability() = GetComputeCapability(stream_exec_); + *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_); + VLOG(2) << "Autotuning result:\n" << log.DebugString(); + tensorflow::Logger::Singleton()->LogProto(log); + } + + auto* profile_results_end = profile_results.data() + profile_results.size(); + + const AutotuneResult* best_result = std::min_element( + profile_results.data(), profile_results_end, + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + // The successful one should have a smaller key, since we are doing + // min_element. If they are both unsuccessful, keep the earlier one in + // the vector by comparing pointers. + return std::make_tuple( + !lhs.has_success(), + protobuf_util::FromDurationProto(lhs.success().run_time()), + &lhs) < std::make_tuple(!rhs.has_success(), + protobuf_util::FromDurationProto( + rhs.success().run_time()), + &rhs); + }); + + if (best_result != profile_results_end && best_result->has_success()) { + return *best_result; } return InternalError( @@ -339,22 +401,23 @@ StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( } auto best_algo = std::move(best_algo_or).ValueOrDie(); - VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm - << " and " << NumBytesToString(best_algo.scratch_bytes) + VLOG(1) << "Setting cudnn conv to use algorithm " + << best_algo.conv().algorithm() << " and " + << NumBytesToString(best_algo.success().scratch_bytes()) << " of scratch memory: " << instr->ToString() - << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled; + << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled(); // Replace instr with a new CustomCall which has the correct 
algorithm, and // whose output shape has the appropriate amount of scratch memory. HloComputation* computation = instr->parent(); Shape new_call_shape = ShapeUtil::MakeTupleShape( {instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})}); + ShapeUtil::MakeShape(U8, {best_algo.success().scratch_bytes()})}); TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, instr->backend_config()); - backend_config.set_algorithm(best_algo.algorithm); - backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled); + backend_config.set_algorithm(best_algo.conv().algorithm()); + backend_config.set_tensor_ops_enabled(best_algo.conv().tensor_ops_enabled()); HloInstruction* new_call = computation->AddInstruction( instr->CloneWithNewOperands(new_call_shape, instr->operands())); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index 642af787afc71586d722ecc7e529ed8b3fa64d33..2e34ba9672314a62290b8a557960a605a98996c7 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/autotuning.pb.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -47,16 +48,10 @@ class CudnnConvAlgorithmPicker : public HloModulePass { StatusOr Run(HloModule* module) override; private: - struct AutotuneResult { - int64 algorithm; - bool tensor_ops_enabled; - int64 scratch_bytes; - absl::Duration runtime; - }; - StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr PickBestAlgorithm(HloCustomCallInstruction* instr); + StatusOr PickBestAlgorithm( + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc index 5aa4f839f4be5f1060480fea98775f8ffada0bdd..958e0b9c6e7b7885f87b90d61ee5b3bbf6ab2702 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc @@ -50,10 +50,10 @@ static HloInstruction* PadInstruction(HloInstruction* instr, auto* zero = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); - PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank()); bool added_padding = false; - for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + for (int64 dim = 0; dim < shape.rank(); ++dim) { if (shape.dimensions(dim) == new_shape.dimensions(dim)) { continue; } diff --git 
a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index 3a09d4d4716950a09d65dd093272482d55ac5c27..17d0f7aa7bf6031148aae79f74f7878d6fca9574 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -219,7 +219,7 @@ bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. PaddingConfig input_padding_config = - MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); + MakeNoPaddingConfig(input->shape().rank()); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 443883a89f66a747def1049bc5afb53fec3c2409..dbcdc2b075bc72f3194af8e555faabb1511376e0 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -109,9 +109,11 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_filter_) + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, /*feature_group_count=*/1, conv_window, + activations, gradients, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); OpMetadata metadata; @@ -147,9 +149,11 @@ 
TEST_F(CudnnConvRewriterTest, builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_filter_) + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, /*feature_group_count=*/1, conv_window, + activations, gradients, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -179,7 +183,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -209,7 +213,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -238,7 +242,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -283,13 +287,15 @@ TEST_F(CudnnConvRewriterTest, 
BackwardInputConvolveEvenPadding) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, - conv_dnums, DefaultPrecisionConfig(2))); + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, conv_dnums, + DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( conv->shape(), ShapeInference::InferConvolveShape( output->shape(), reverse_kernel->shape(), - /*feature_group_count=*/1, conv_window, conv_dnums) + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, conv_dnums) .ValueOrDie())); auto module = CreateNewVerifiedModule(); @@ -332,10 +338,12 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -365,11 +373,12 @@ TEST_F(CudnnConvRewriterTest, builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( output->shape(), kernel->shape(), /*feature_group_count=*/1, - default_conv_window_, tf_default_dnums_for_backward_input_) + /*batch_group_count=*/1, default_conv_window_, + tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, - default_conv_window_, tf_default_dnums_for_backward_input_, - DefaultPrecisionConfig(2))); + 
/*batch_group_count=*/1, default_conv_window_, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -415,15 +424,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -465,15 +474,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -519,15 +528,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { forward_conv_col_dim->set_base_dilation(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); const HloComputation* entry_computation = @@ -574,15 +583,15 @@ TEST_F(CudnnConvRewriterTest, forward_conv_col_dim->set_padding_high(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -599,7 +608,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape(); const string module_str = absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 3425e1b4942aaf1011ba1bf1c50dd7e79c1f9807..b628f27f4b2ba8ccf17fd531d8a0c25cb99d9396 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -395,32 +395,36 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { + RunConvOptions options) { ScratchBufAllocator scratch_allocator(scratch_buf); return RunCudnnConv(conv, operand_buffers, result_buffer, &scratch_allocator, - stream, profile_result); + stream, options); } Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { + RunConvOptions options) { TF_ASSIGN_OR_RETURN(CudnnConvParams params, 
GetCudnnConvParams(conv, operand_buffers, result_buffer)); + if (options.algo_override) { + params.algorithm = AlgorithmConfig(*options.algo_override); + } + PrimitiveType output_primitive_type = conv->shape().tuple_shapes(0).element_type(); switch (output_primitive_type) { case F16: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); case F32: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); case F64: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); default: LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h index edbc75a94a1238540390b93f0fa5217852c7781f..25b2461ca61251c6cb7b89b1f91da0f1636a3647 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h @@ -28,6 +28,14 @@ limitations under the License. namespace xla { namespace gpu { +struct RunConvOptions { + // Nullable output-parameter pointer for profiling results. + se::dnn::ProfileResult* profile_result = nullptr; + + // Use this algorithm, instead of the one from the instruction. + absl::optional algo_override; +}; + // This file contains low-level routines for running cudnn convolutions. // Calls into cudnn to run the specified convolution. 
@@ -46,13 +54,13 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); + RunConvOptions = {}); Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); + RunConvOptions = {}); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 2ab754a471070d5f90a3eaebd0600ff180d2fe5d..dd74788a0e2940e88dfca1ffa4a4cdad7c1997e2 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -270,6 +270,16 @@ StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, prim_type); } +StatusOr GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type, + llvm::Value* value) { + return EmitLibdeviceMathCall("__nv_sqrt", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value) { + return EmitLibdeviceMathCall("__nv_rsqrt", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) { @@ -308,9 +318,11 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( false); // No variadic arguments. // Declares the callee if it is not declared already. 
- llvm::Function* callee = llvm::cast( - b_->GetInsertBlock()->getModule()->getOrInsertFunction( - llvm_ir::AsStringRef(callee_name), callee_type)); + llvm::Function* callee = llvm::dyn_cast( + b_->GetInsertBlock() + ->getModule() + ->getOrInsertFunction(llvm_ir::AsStringRef(callee_name), callee_type) + .getCallee()); for (auto attribute : attributes) { callee->addFnAttr(attribute); @@ -446,7 +458,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( return Load(accum_ptr); }; case HloOpcode::kReduce: - // TODO(b/112040122): This should be supported. + // TODO(b/118332391): This should be supported. CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; return [=, &operand_to_generator]( const IrArray::Index& output_index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index e8b56a39ce58b6aab35c1c977553c7ff7e753273..2aedbf05abb31c88b9988dc1d90e921e9473d25b 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -76,6 +76,12 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitExpm1(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr EmitSqrt(PrimitiveType prim_type, + llvm::Value* value) override; + + StatusOr EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value) override; + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) override; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 470457935acacb8940af241dadb393d770786939..91930eccdff94bb2fc85636f3a4b2d661c618d87 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -35,7 +35,7 @@ namespace { // Traverses users of tuple shape, adding leaf instructions to 'instructions'. 
void MaybeResolveTupleElements(HloInstruction* instruction, std::vector* instructions) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { for (auto tuple_user : instruction->users()) { MaybeResolveTupleElements(tuple_user, instructions); } diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 27f07b1d58125092c1ed6734b238e4ae0f11c4aa..a7053e6a013be3ccf5725cbe003558be77104af1 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -206,6 +206,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { return &DoGemm; case C64: return &DoGemm>; + case C128: + return &DoGemm>; default: LOG(FATAL) << "Unsupported type."; } @@ -221,6 +223,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm; case C64: return &DoGemmWithAlgorithm>; + case C128: + return &DoGemmWithAlgorithm>; default: LOG(FATAL) << "Unsupported type."; } @@ -235,6 +239,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { return &DoGemmAutotune; case C64: return &DoGemmAutotune>; + case C128: + return &DoGemmAutotune>; default: LOG(FATAL) << "Unsupported type."; } @@ -255,6 +261,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF64; case C64: return se::blas::ComputationType::kComplexF32; + case C128: + return se::blas::ComputationType::kComplexF64; default: LOG(FATAL) << "Unsupported type."; } @@ -315,8 +323,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); - CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, - ShapeUtil::Rank(output_shape_)); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank()); int64 
row_dim = dim_nums.lhs_batch_dimensions_size(); int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; @@ -421,7 +428,8 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, scratch_data = scratch_mem->device_memory(); } const MatrixDescriptor scratch_descriptor( - scratch_data, false, output_num_cols, output_num_rows, batch_size); + scratch_data, false, output_matrix.num_rows, output_matrix.num_cols, + batch_size); StatusOr best_algorithm = GetGemmAutotuneFn( element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index ae2e718db29803a085401969a7d9b09abf690a6c..434060ad89dac7ad65c790c8c0a7f3d6ad62a25a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -218,7 +218,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); - CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK(literal.shape().IsArray()); if (!ShouldEmitLiteralInLlvmIr(literal)) { VLOG(3) << "H2D memcpy for constant with shape " << ShapeUtil::HumanString(literal.shape()); @@ -310,12 +310,34 @@ StatusOr GpuExecutable::ExecuteOnStream( TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); se::DeviceMemoryBase src_base = buffer_allocations->GetDeviceAddress(slice.index()); CHECK(!src_base.is_null() || src_base.size() == 0); - *device_memory = src_base; + if (!slice.allocation()->is_entry_computation_parameter()) { + // If the buffer coming out of the result is from a parameter, it + // means the caller aliased some parameter buffer to an output one + // (via the HloInputOutputAliasConfig API). 
If that is the case, the + // caller will receive a partially complete scoped shaped buffer, + // which they will have to fill up on return. + // Unfortunately the interface to the execute APIs are ShapedBuffer + // pointer based, which assumes caller ownership, and hence a buffer + // coming from there cannot be part of the new ScopedShapedBuffer we + // create for the result (which assumes ownership). + *device_memory = src_base; + } else { + const HloInputOutputAliasConfig& input_output_alias = + module().input_output_alias_config(); + auto output_alias = input_output_alias.GetAliasedOutput( + slice.allocation()->parameter_number(), + slice.allocation()->param_shape_index()); + CHECK(output_alias) + << "Ouput buffer is coming from parameter " + << slice.allocation()->parameter_number() << " at index " + << slice.allocation()->param_shape_index() + << ", but no alias exists"; + CHECK_EQ(*output_alias, index); + } buffers_in_result.insert(src_base); return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 452e763a8eaadc805cd3a3859a68e2a31598fd36..842ba2fdcd31a451cec1be543e102e0a46077f38 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -42,15 +42,13 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, int64 max_rank = -1; const Layout* max_rank_layout; for (HloInstruction* param : params) { - if (ShapeUtil::IsArray(param->shape()) && - ShapeUtil::Rank(param->shape()) > max_rank) { - max_rank = ShapeUtil::Rank(param->shape()); + if (param->shape().IsArray() && param->shape().rank() > max_rank) { + max_rank = param->shape().rank(); max_rank_layout = &param->shape().layout(); } } return absl::c_all_of(params, [&](HloInstruction* param) { - return (!ShapeUtil::IsArray(param->shape())) || - (ShapeUtil::Rank(param->shape()) < max_rank) || + return (!param->shape().IsArray()) || 
(param->shape().rank() < max_rank) || (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); }); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index e9d7ba1c4cfa865532a0d06c2ed883a2fea4e2cd..9f0de3f794decb7b878b67c96030f8e11b0555fe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -48,7 +48,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr); // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. -// This function works for both, sibling and producer-conumser multi-output +// This function works for both, sibling and producer-consumer multi-output // fusion. // So far, multi-output fusion is supported for loop fusions and reduce // input fusions only. It is up to the caller to ensure the instructions diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc index 4268fb2c7a813b3b53e4cd48746028a7b369f28e..4765f67c4b17e97419182e341573f75ad3d6ac30 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index f59da2caa18646676297e66dd329c66fb5fddf1b..a6d80f0b6dddb3d8d0fd00c639e11c71da6a9f09 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -196,9 +196,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, - ShapeUtil::Rank(instruction->shape())); + instruction->shape().rank()); for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { - CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2); + CHECK_LT(batch_dim, instruction->shape().rank() - 2); } // Set both inputs and the output to default layout. @@ -215,18 +215,18 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); } else if (instruction->opcode() == HloOpcode::kSort && - ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) { + instruction->operand(0)->shape().rank() > 1) { // Make sure that all the operands and the output(s) have the same layout. 
Shape keys_shape = instruction->operand(0)->shape(); Layout keys_layout = - LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape)); + LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank()); for (int64 i = 0; i < instruction->operand_count(); ++i) { Shape shape = instruction->operand(i)->shape(); *shape.mutable_layout() = keys_layout; TF_RETURN_IF_ERROR( constraints->SetOperandLayout(shape, instruction, i)); const LogicalBuffer* output_buffer; - if (ShapeUtil::IsArray(instruction->shape())) { + if (instruction->shape().IsArray()) { TF_ASSIGN_OR_RETURN( output_buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, @@ -240,6 +240,32 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(keys_layout, *output_buffer)); } + } else if (instruction->opcode() == HloOpcode::kTriangularSolve) { + // TODO(phawkins): Ideally we would relax this constraint. What we + // actually want is that: + // a) the batch dimensions are major, in no particular order. + // b) the two minor dimensions are in fortran (column-major) order, + // although for the 'a' argument we could potentially accept row-major + // order and fold the transpose into the operator. 
+ auto set_fortran_layout = [](Shape* shape) { + LayoutUtil::SetToDefaultLayout(shape); + int n = shape->mutable_layout()->minor_to_major_size(); + CHECK_GE(n, 2); + std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0), + shape->mutable_layout()->mutable_minor_to_major()->at(1)); + }; + Shape op0_shape = instruction->operand(0)->shape(); + Shape op1_shape = instruction->operand(1)->shape(); + Shape output_shape = instruction->shape(); + set_fortran_layout(&op0_shape); + set_fortran_layout(&op1_shape); + set_fortran_layout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op1_shape, instruction, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 2ffc8bfb49b205dced0d540ba72426e72d95e596..391029e574622925b2a7e801a7d41d95e49a1cfb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -368,12 +368,21 @@ TEST_F(LayoutAssignmentTest, DotLayout) { TEST_F(LayoutAssignmentTest, SortLayout) { const char* hlo_text = R"( HloModule SortLayout + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + ENTRY sort { - keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}}) + keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}}) values = f32[2,3]{1,0} parameter(0) transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), - dimensions={1} + dimensions={1}, to_apply=compare })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git 
a/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e38ceca18de30e0e1fa75a7a4bd865e000b7d22 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace gpu { + +StatusOr GpuSanitizeConstantNames::Run(HloModule* module) { + bool changed = false; + + NameUniquer instr_name_uniquer(/*separator=*/"_"); + // Collect the names used for the non-constant HLO instructions. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kConstant) { + continue; + } + + const string& old_name = instr->name(); + instr->UniquifyName(&instr_name_uniquer); + CHECK_EQ(old_name, instr->name()); + 
} + } + + // Sanitize the names for the constant HLO instructions and make them unique. + // This is not merged into the above loop because we don't want this pass to + // change the names of non-constant instructions, that is, if a constant HLO + // conflicts with a non-constant HLO, we change the name of the constant HLO + // even though the non-constant HLO comes after in the HLO module. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() != HloOpcode::kConstant) { + continue; + } + string sanitized_name = llvm_ir::SanitizeConstantName(*instr); + instr->SetAndSanitizeName(sanitized_name); + instr->UniquifyName(&instr_name_uniquer); + changed = true; + } + } + + return changed; +} // namespace gpu + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h new file mode 100644 index 0000000000000000000000000000000000000000..8d583d047e25698e86032020b7fc20df87f5ab68 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Sanitizes HLO instruction names for the GPU backend. Currently, it only +// replaces . and - in the HLO constant instruction names with _ to please the +// LLVM PTX backend. +class GpuSanitizeConstantNames : public HloModulePass { + public: + absl::string_view name() const override { return "sanitize-constant-names"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5adee8cc61f18f356406d8c089dd43565957739 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using SanitizeConstantNamesTest = HloTestBase; + +TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) { + const char *const kHloString = R"( + HloModule HyphenInInstructionName + ENTRY kernelEntry { + ROOT equal-to = s32[2]{0} constant({42, 73}) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + HloInstruction *root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "equal_to"); +} + +TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) { + const char *const kHloString = R"( + HloModule HyphenInInstructionName + ENTRY kernelEntry { + ROOT equal.to = s32[2]{0} constant({42, 73}) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + HloInstruction *root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "equal_to"); +} + +TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) { + const char *const kHloString = R"( + HloModule BufferSanitizedName + ENTRY kernelEntry { + equal.to = s32[2]{0} constant({42, 73}) + equal-to = s32[2]{0} constant({67, 3}) + ROOT equal_to = s32[2]{0} add(equal.to, equal-to) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + 
EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"), op::Constant()); + EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"), op::Constant()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f3c274429242d5c989146d14ea523b5910408cff..e593f535642e15f28a4a1c1f321881ba3c694548 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" @@ -59,7 +58,7 @@ Status GpuTransferManager::TransferLiteralToInfeed( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( shape, [&](const Shape& literal_subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(literal_subshape)) { + if (literal_subshape.IsArray()) { int64 tuple_element_size = GetByteSizeRequirement(literal_subshape); TF_ASSIGN_OR_RETURN( *buffer_tree.mutable_element(index), @@ -126,13 +125,12 @@ static void ShapeTreeToLiteral( ShapeTree>* shape_tree, ShapeIndex* index) { const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); return; } - CHECK(ShapeUtil::IsTuple(shape)) - << ShapeUtil::HumanStringWithLayout(shape); + CHECK(shape.IsTuple()) << ShapeUtil::HumanStringWithLayout(shape); const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); index->push_back(0); for (int64 i = 0; i < 
tuple_element_count; ++i) { @@ -158,7 +156,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( std::unique_ptr* buffer) { const Shape& shape = ShapeUtil::GetSubshape(literal_shape, index); // Do not transfer tuple index buffers. - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return; } *buffer = absl::make_unique( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 51627402b45f594dab3480129ba182d54d01b811..69aaaceca112364a4fd562f6a5eff1629fd3fc54 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -45,10 +46,10 @@ void HloToIrBindings::EmitBasePointersForHlos( // An HLO can have duplicated operands. This data structure remembers which // operand HLOs are already bound to avoid rebinding the same HLO. 
- std::set already_bound_for_this_function; + absl::flat_hash_set already_bound_for_this_function; auto arg_iter = function->arg_begin(); for (const HloInstruction* io_hlo : io_hlos) { - if (!already_bound_for_this_function.count(io_hlo)) { + if (!already_bound_for_this_function.contains(io_hlo)) { if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter)); } else { @@ -63,7 +64,7 @@ void HloToIrBindings::EmitBasePointersForHlos( temp_buffer_base_->setName("temp_buffer"); for (const HloInstruction* non_io_hlo : non_io_hlos) { - if (already_bound_for_this_function.count(non_io_hlo)) { + if (already_bound_for_this_function.contains(non_io_hlo)) { continue; } already_bound_for_this_function.insert(non_io_hlo); @@ -280,7 +281,7 @@ string HloToIrBindings::ToString() const { StrAppend(&s, " ", instr->ToString()); const ShapeTree& shape_tree = it->second; - if (!ShapeUtil::IsTuple(instr->shape())) { + if (!instr->shape().IsTuple()) { const llvm::Value* val = shape_tree.begin()->second; StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n"); continue; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index c0edae530cedba45c897b07b7b9cc72eaaab397c..f57b594e9c18078a3bbbf4d2b4db7e989c4edfdd 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -61,7 +62,7 @@ class HloToIrBindings { // Returns whether `hlo` is bound to an LLVM IR value. 
bool BoundToIrValue(const HloInstruction& hlo) const { - return base_ptrs_.count(&hlo); + return base_ptrs_.contains(&hlo); } llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } @@ -110,7 +111,8 @@ class HloToIrBindings { // For an instruction that generates multiple outputs, the root will be a // tuple shape. The IrArray for each element output is stored in the subnode // in the ShapeTree. - std::unordered_map> base_ptrs_; + absl::flat_hash_map> + base_ptrs_; // The address of the memory block that contains all temporary buffers. llvm::Value* temp_buffer_base_ = nullptr; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 8c3a026740851767855beae59d6a3c92f7a0d6bd..676380c3b10f9a20c641eea0d9a948a26becaddc 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -36,6 +36,21 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, ShapeTree infeed_buffers = GetOrCreateInfeedManager()->BlockingGetNextDestination(); + // infeed_slices_'s shape should be a tuple of shape (buffers, token). + const auto& infeed_shape = infeed_slices_.shape(); + TF_RET_CHECK(infeed_shape.IsTuple()) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK(infeed_shape.tuple_shapes().size() == 2) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK(infeed_shape.tuple_shapes(1).IsToken()) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK( + ShapeUtil::Equal(infeed_buffers.shape(), infeed_shape.tuple_shapes(0))) + << "Expected infeed of shape " + << ShapeUtil::HumanStringWithLayout(infeed_shape.tuple_shapes(0)) + << " but was " + << ShapeUtil::HumanStringWithLayout(infeed_buffers.shape()); + { // The infeed buffer has an extra outer tuple with a token. Adjust the index // accordingly. 
@@ -45,7 +60,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, const Shape& shape = ShapeUtil::GetSubshape(infeed_buffers.shape(), ShapeIndexView(index, 1)); // For the leaf buffers of the tuple copy the elements directly. - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { const BufferAllocation::Slice& tuple_element_buffer = infeed_slices_.element(index); se::DeviceMemoryBase tuple_element_address = diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 6151dd8ff4c92bb81bd756c68cc9377633c8c9d5..f07141029cbf8b034b74548f6fca8f1628589f0c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -282,22 +282,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, int64 operand_index) { - const HloInstruction* producer = consumer->operand(operand_index); - // The IR emitter has limited support for non-loop fusions with multi output - // at present. - // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion. - if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { - return false; - } - // Multi-output fusion requires instructions with compatible shapes. - if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { - return false; - } - // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for - // multi-output fusion. In particular, do not check whether an instruction is - // expensive to duplicate, since this doesn't matter here. 
- return GpuInstructionFusion::ShouldFuse(consumer, operand_index); + return false; } HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 688604cd36e5a45debf855aacd29d05ecda92341..a05ab86cf77a134a1fc387d93cb482aa1ff5345b 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -506,202 +506,11 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) { })") .ValueOrDie(); - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT( - fusion->fused_expression_root(), - op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); -} - -TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { - // tanh --> add --> tuple - // \---------------/ - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - tanh = f32[4,3]{1,0} tanh(p0) - add = f32[4,3]{1,0} add(tanh, p1) - ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) - })") - .ValueOrDie(); - - // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. + // Multi-output fusion is disabled here and performed in the + // GpuMultiOutputFusion pass instead. 
ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()) - << module->ToString(); -} - -TEST_F(InstructionFusionTest, MultiOutputFusion2) { - // sub --> add1 --\--------\ - // \----------> add2 --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[4,3]{1,0} parameter(2) - sub = f32[4,3]{1,0} subtract(p0, p2) - add1 = f32[4,3]{1,0} add(sub, p1) - add2 = f32[4,3]{1,0} add(sub, add1) - ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT(fusion->fused_expression_root(), - op::Tuple(op::Add(op::Subtract(), op::Add()), - op::Add(op::Subtract(), op::Parameter()))); -} - -TEST_F(InstructionFusionTest, MultiOutputFusion3) { - // sub --> add1 ----\--------\ - // \ --> add2 --> add3 --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[4,3]{1,0} parameter(2) - p3 = f32[4,3]{1,0} parameter(3) - sub = f32[4,3]{1,0} subtract(p0, p2) - add1 = f32[4,3]{1,0} add(sub, p1) - add2 = f32[4,3]{1,0} add(p2, sub) - add3 = f32[4,3]{1,0} add(add1, add2) - ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not 
been - // duplicated. - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT(fusion->fused_expression_root(), - op::Tuple(op::Add(op::Add(), op::Add()), - op::Add(op::Parameter(), op::Subtract()))); -} - -TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { - // sub --> mul ---\ - // \--> call --> add --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - c = f32[] constant(42) - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - sub = f32[4,3]{1,0} subtract(p0, p1) - mul = f32[4,3]{1,0} multiply(sub, c) - call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" - add = f32[4,3]{1,0} add(mul, call) - ROOT tuple = (f32[4,3]{1,0}) tuple(add) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - // Visit instructions in post order to detect cycles. - // TODO(tjoerg): Add cycle detection to the HloVerifier. - class DummyVisitor : public DfsHloVisitorWithDefault { - public: - DummyVisitor() {} - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - } visitor; - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - // Accept will return a FailedPrecondition when a cycle is detected. 
- EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); - } -} - -TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { - // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) - // \-------------------------/ - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[2,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[2,3]{1,0} parameter(2) - sub = f32[2,3]{1,0} subtract(p0, p2) - add = f32[4,3]{1,0} add(sub, p1) - ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) - })") - .ValueOrDie(); - - // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` - // have incompatible shapes, expect that no multi-output fusion happens. - ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()) - << module->ToString(); -} - -TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { - auto module = ParseHloString(R"( - HloModule test_module - - add_computation { - add_lhs = f32[] parameter(0) - add_rhs = f32[] parameter(1) - ROOT add_root = f32[] add(add_lhs, add_rhs) - } - - fused_computation { - p1 = f32[10] parameter(0) - zero = f32[] constant(0) - ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, - to_apply=add_computation - } - - ENTRY entry { - p0 = f32[10] parameter(0) - mul = f32[10] multiply(p0, p0) - fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation - ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) - })") - .ValueOrDie(); - - // Multi-output fusion is not supported for non-loop fusions at present. Since - // `fused_computation` is a input fusion, expect no multi-output fusion to - // happen. 
- ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()) - << module->ToString(); + .ValueOrDie()); } TEST_F(InstructionFusionTest, FuseScalarConstant) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 33e41a2782b5932430eea621d3cea2c6634f292f..3ed6553f9205803cfa17772b890c449cfb457c89 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,7 +20,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -40,7 +39,7 @@ namespace { // Return whether the given shape is rank 2 excluding the batch dimensions. 
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) { - return ShapeUtil::Rank(shape) == batch_dimensions_size + 2; + return shape.rank() == batch_dimensions_size + 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes @@ -54,7 +53,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64 || output_primitive_type == C64); + output_primitive_type == F64 || output_primitive_type == C64 || + output_primitive_type == C128); return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) && IsRank2(rhs_shape, batch_dimensions_size) && IsRank2(output_shape, batch_dimensions_size) && @@ -154,20 +154,17 @@ bool IsReductionToVector(const HloInstruction& reduce) { const HloInstruction* input = reduce.operand(0); std::vector dims_to_keep; for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) { - if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(), - dim)) { + if (!absl::c_linear_search(reduce.dimensions(), dim)) { dims_to_keep.push_back(dim); } } return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), dims_to_keep) && - ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions( - [&dims_to_keep](int64 dim) { - return std::count( - dims_to_keep.begin(), - dims_to_keep.end(), dim); - }, - input->shape())); + ShapeUtil::Equal( + reduce.shape(), + ShapeUtil::FilterDimensions( + [&](int64 dim) { return absl::c_count(dims_to_keep, dim); }, + input->shape())); } // This emits a device-side call to diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6693f66d62d8b04d1b78e001fdb515b34539c67f..8f010ab27a6c99b97e7808218de908ce558b0fe7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ 
b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -430,7 +430,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { auto on_false = tuple_select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); - TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RET_CHECK(tuple_select->shape().IsTuple()); llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select), GetIrArray(*pred, *tuple_select), GetBasePointer(*on_true), GetBasePointer(*on_false), @@ -492,8 +492,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); result = InsertValue(result, value.first, {0}); result = InsertValue(result, value.second, {1}); - } else { + } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { result = FMul(lhs_value, rhs_value); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); + result = Mul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -583,9 +586,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum_imag = Imag(accum, &b_); llvm::Value* imag_sum = FAdd(accum_imag, value.second); updated_accum = InsertValue(updated_accum, imag_sum, {1}); - } else { + } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { llvm::Value* product = FMul(lhs_element, rhs_element); updated_accum = FAdd(accum, product); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); + llvm::Value* product = Mul(lhs_element, rhs_element); + updated_accum = Add(accum, product); } Store(updated_accum, accum_address); @@ -637,9 +644,9 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { return Unimplemented("Hit a case for fft that is not implemented on GPU."); } -Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitter::HandleAllReduce(HloInstruction* crs) { // TODO(b/33011107): Support cross 
replica sum on GPU. - return Unimplemented("CrossReplicaSum is not implemented on GPU."); + return Unimplemented("AllReduce is not implemented on GPU."); } Status IrEmitter::HandleParameter(HloInstruction* parameter) { @@ -647,8 +654,8 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { } Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support variadic reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { + // TODO(b/118332391): Support variadic reduce. + if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on GPU"); } auto arg = reduce->operand(0); @@ -783,7 +790,7 @@ StatusOr IrEmitter::ComputeNestedElement( std::vector IrEmitter::ConstructIrArrayForOutputs( const HloInstruction& hlo) { std::vector output_arrays; - if (ShapeUtil::IsTuple(hlo.shape())) { + if (hlo.shape().IsTuple()) { int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); output_arrays.reserve(num_outputs); for (int64 i = 0; i < num_outputs; ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 2da46c016935d0e927879bbfb0d05cfc4899d818..f380aee9d3c06a29b503c81c7bd3846dbccf6ce5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -81,7 +81,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 
fb040aff30d48bf5817946ce53d37bc6685941e4..0cc65ebb52737aa9bb8866eb07278a2319aa797b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -22,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "absl/algorithm/container.h" -#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -38,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" @@ -60,6 +59,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -89,6 +89,9 @@ namespace xla { namespace gpu { using llvm_ir::KernelMappingScheme; +using EmitElementFunction = + std::function; namespace { @@ -293,13 +296,12 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, auto shape_in_range = [&](const Shape& s) { bool in_range = true; - ShapeUtil::ForEachSubshape( - s, [&](const Shape& sub_shape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(sub_shape) && - !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { - in_range = false; - } - }); + ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, + const ShapeIndex& /*index*/) { + if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); return in_range; }; @@ -485,6 +487,41 @@ Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { return Status::OK(); } +Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) { + auto has_fortran_layout = [](const Layout& layout) { + int n = layout.minor_to_major_size(); + return layout.minor_to_major(0) == n - 2 && + layout.minor_to_major(1) == n - 1; + }; + TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout())); + TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout())); + TF_RET_CHECK(has_fortran_layout(hlo->shape().layout())); + + std::vector> thunks; + + // Triangular solve is in-place on 'b', so copy 'b' to the output if they + // aren't the same buffer. 
+ auto operand_buffer = GetAllocationSlice(*hlo->operand(1)); + auto destination_buffer = GetAllocationSlice(*hlo); + if (operand_buffer != destination_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo)); + } + + thunks.push_back(BuildTriangularSolveThunk(hlo)); + + // Elide the sequential thunk if there's no copy. + if (thunks.size() == 1) { + AddThunkToThunkSequence(std::move(thunks[0])); + } else { + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), hlo)); + } + return Status::OK(); +} + Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { @@ -543,96 +580,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // HandleFusion specializes reduction from a multi-dimensional array to // a 1D array. The specialized version requires a initializer thunk that // initializes the output array to the initial value of the reduce. - if (root->opcode() == HloOpcode::kReduce && - ShapeUtil::IsTuple(root->shape())) { - // TODO(b/112040122): Support variadic reduce. + if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) { + // TODO(b/118332391): Support variadic reduce. return Unimplemented("Variadic reduce is not supported on GPU"); } - VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); - std::vector> thunks; - absl::Span output_instructions = - root->opcode() == HloOpcode::kTuple - ? root->operands() - : absl::Span(&root, 1); - - // For multi-output fusion emit an initializer for each tuple element. - // Otherwise it's sufficient to just initialize the single output. 
- HloInstruction* first_reduce = nullptr; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (output_instructions[i]->opcode() == HloOpcode::kReduce) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr initializer_thunk, - BuildInitializerThunk(fusion, output_instructions[i] == root - ? ShapeIndex() - : ShapeIndex({i}))); - thunks.push_back(std::move(initializer_thunk)); - first_reduce = - first_reduce == nullptr ? output_instructions[i] : first_reduce; - } - } - CHECK(first_reduce != nullptr); - std::unique_ptr kernel_thunk = - BuildKernelThunk(fusion, /*implements_whole_instruction=*/false); - GpuElementalIrEmitter elemental_emitter( - hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), - &elemental_emitter); - TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - - // For multi-output fusion CHECK the constraints and feed all the - // reduces into a single loop code generator. Single-output reduce - // fusion is a special case of that. - InlinedVector input_gens; - InlinedVector init_value_gens; - std::vector> - extra_output_gens; - InlinedVector reducers; - InlinedVector reduce_output_shapes; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - const HloInstruction* inst = output_instructions[i]; - ShapeIndex output_shape_index; - if (root->opcode() == HloOpcode::kTuple) { - output_shape_index = {i}; - } - if (inst->opcode() == HloOpcode::kReduce) { - CHECK(IsReductionToVector(*inst)) - << "Only reductions to vector are supported"; - // Shapes, layouts and dimensions must be the same for all reduces - // inside of this fusion. 
- CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); - CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), - inst->operand(0)->shape())); - CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), - inst->operand(1)->shape())); - CHECK(first_reduce->dimensions() == inst->dimensions()); - input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); - init_value_gens.push_back( - fused_emitter.GetGenerator(inst->operand(1))); - reducers.push_back(inst->to_apply()); - reduce_output_shapes.push_back(std::move(output_shape_index)); - } else { - // For extra outputs we can relax shape equality to allow different - // types (with the same number of elements). Layouts still have to - // match. - CHECK(ShapeUtil::CompatibleIgnoringElementType( - first_reduce->operand(0)->shape(), inst->shape())); - CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), - inst->shape().layout())); - extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), - std::move(output_shape_index)); - } - } - const Shape& input_shape = first_reduce->operand(0)->shape(); - TF_CHECK_OK(EmitReductionToVector( - kernel_thunk.get(), first_reduce, input_shape, input_gens, - init_value_gens, first_reduce->dimensions(), reducers, - reduce_output_shapes, extra_output_gens)); - thunks.push_back(std::move(kernel_thunk)); - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), fusion); - AddThunkToThunkSequence(std::move(sequential_thunk)); - return Status::OK(); + return EmitReductionToVector(fusion); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -702,13 +654,12 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { } Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const HloInstruction* reduce, const IrArray::Index& index, + const HloInstruction* unnested_hlo, const IrArray::Index& index, absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { - const HloInstruction* output = 
reduce->parent()->FusionInstruction(); llvm::Value* extra_output_address = - GetIrArray(*output, *output, extra_output_gens[i].second) + GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second) .EmitArrayElementAddress(index, &b_, "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, @@ -718,984 +669,13 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( return Status::OK(); } -Status IrEmitterUnnested::EmitReductionToScalar( - KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // Number of elements processed by a single thread. - constexpr int64 kTileSize = 16; - int64 num_elems = ShapeUtil::ElementsIn(input_shape); - - // Round up the number of tiles to a multiple of the warp size. This is - // necessary for correctness. We launch one thread per tile, and if the - // number of threads isn't a multiple of the number of the warp size, our - // shuffles will read from inactive threads, producing undefined values. - int64 num_tiles = - RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); - - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {num_tiles}, {0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - - llvm::Type* index_ty = - GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // Check whether every thread will process a full tile's worth of elements - // without reading outside the bounds of the input. If this is true, we can - // skip some bounds checks in the final algorithm. 
- bool all_threads_in_bounds = num_tiles * kTileSize == num_elems; - - // __global__ void full_reduce_kernel() { - // x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x; - // x = x_in_tiles * kTileSize; - // - // partial_result = init_value; - // if (all_threads_in_bounds || x + kTileSize <= num_elems) { - // for (i = 0; i < kTileSize; ++i) { - // partial_result = Reducer(partial_result, input[x + i]); - // } - // } else { - // for (i = 0; i < kTileSize; ++i) { - // if (x + i < num_elems) { - // partial_result = Reducer(partial_result, input[x + i]); - // } - // } - // } - // for (i = warpSize / 2; i > 0; i /= 2) { - // partial_result = Reducer(partial_result, - // __shfl_down(partial_result, i)); - // } - // if (lane_id == 0) { - // AtomicReducer(&output[y], partial_result); - // } - // } - // - // // Choose num_blocks and threads_per_block such that: - // // - // // num_blocks * threads_per_block = - // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), - // // - // // and threads_per_block is a multiple of warpSize. - // reduce_kernel // - auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { - const int num_reduces = reducers.size(); - llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - - llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); - - // Emit an inner for-loop that reduces the elements in the tile. 
- auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_constant(0), - index_typed_constant(kTileSize), index_typed_constant(1), &b_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &b_); - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); - - // Emit code that reads the input element and accumulates it to - // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - - IrArray::Index input_index( - /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = Alloca(element_ir_type); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); - }; - - // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's - // immediately beyond the tile. - llvm::Value* x_end = - NSWAdd(index_typed_constant(kTileSize), - NSWMul(x_in_tiles, index_typed_constant(kTileSize))); - // The tile is entirely in bound if all_threads_in_bounds or - // x_end <= num_elems. 
- llvm::Value* tile_in_bounds = - Or(ICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_); - int bit_width = llvm_ir::GetSizeInBits(element_ir_type); - // bitcast cannot be applied to aggregate types (even packed ones), so we - // instead bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? 
b_.getIntNTy(bit_width) - : element_ir_type; - for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; - shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = - Alloca(element_ir_type, nullptr, "result_from_other_lane"); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = - Load(BitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); - CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) - << "Requires block size a multiple of the warp size, otherwise we " - "will read undefined elements."; - Store(EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], result_from_other_lane}, - partial_reduction_result_addresses[i])); - } - } - - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - - // Emit an atomic operation that accumulates the partial reduction result of - // lane 0 (which holds the partially accumulated result for its warp) to the - // output element. 
- llvm::Value* lane_id = - URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); - llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index( - /*linear=*/b_.getInt64(0), - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); - } - return Status::OK(); - }; - - // Emit a parallel loop that iterates through all input tiles, one per thread. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -Status IrEmitterUnnested::EmitColumnReduction( - KernelThunk* kernel_thunk, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // Divide the input matrix into tiles of size KxL. For example, when the - // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like - // - // 0123 - // 0123 - // 4567 - // 4567 // Numbers indicate tile IDs. - // - // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. - // - // We choose 128 as the tile size based on empirical evidence. It's big enough - // to reduce the amount of atomic adds in the end, maximizing the memory - // bandwidth. 
A tile width of 2 allows for high memory bandwidth utilization - // on 16b input data. - constexpr int64 kTileHeight = 128; - constexpr int64 kTileWidth = 2; - - // If the height is not a multiple of kTileHeight, we pad the bottom of the - // input matrix. - const int64 height_in_tiles = CeilOfRatio(height, kTileHeight); - // If width is not a multiple of kTileWidth the rightmost thread will process - // fewer input elements. - const int64 width_in_tiles = CeilOfRatio(width, kTileWidth); - Shape tiled_input_shape = - ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(), - {height_in_tiles, width_in_tiles}, {1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - - // TODO(b/110211620): Convert to use i32 index_type when it is possible. - llvm::Type* index_ty = b_.getInt64Ty(); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < height_in_tiles * width_in_tiles; - // linear_index += blockDim.x * gridDim.x) { - // y_in_tiles = linear_index / width_in_tiles; - // x_in_tiles = linear_index % width_in_tiles; - // - // partial_results[kTileWidth] = init_values; - // tile_in_y_bounds = height % kTileHeight == 0 || - // y_in_tiles * kTileHeight + kTileHeight <= height; - // tile_in_x_bounds = width % kTileWidth == 0 || - // x_in_tiles * kTileWidth + kTileWidth <= width; - // // The implementation handles y and x bound checks separately. 
- // if (tile_in_y_bounds && tile_in_x_bounds) { - // for (y_offset : range(kTileHeight)) { - // y = y_in_tiles * kTileHeight + y_offset; - // for (x_offset : range(kTileWidth)) { - // x = x_in_tiles * kTileWidth + x_offset; - // partial_result = Reducer(partial_result[x_offset], input[y][x]); - // } - // } - // } else { - // for (y_offset : range(kTileHeight)) { - // y = y_in_tiles * kTileHeight + y_offset; - // for (y_offset : range(kTileHeight)) { - // x = x_in_tiles * kTileWidth + x_offset; - // if (y < height && x < width) { - // partial_result = Reducer(partial_result, input[y][x]); - // } - // } - // } - // } - // for (x_offset : range(kTileWidth)) { - // AtomicReducer(&output[x + x_offset], partial_result[x_offset]); - // } - // } - auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { - const int num_reduces = reducers.size(); - // Emit the loop body that reduces one tile. - llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - } - - // Emit an inner for-loop that partially reduces the elements in the given - // tile. 
- llvm::Value* y_in_tiles = tile_index[0]; - llvm::Value* x_in_tiles = tile_index[1]; - - y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); - - auto emit_tile_element_loop = [=](bool tile_in_y_bounds, - bool tile_in_x_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_constant(0), - index_typed_constant(kTileHeight), index_typed_constant(1), &b_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &b_); - llvm::Value* y = - NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); - - // Unless we know that y is in bounds, we have to emit a check before - // reading from the input. - if (!tile_in_y_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); - - // Emit code that reads the input element and accumulates it to - // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); - // Unless we know that x is in bounds, we have to emit a check before - // reading from the input. - if (!tile_in_x_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - llvm::Value* input_address = Alloca(element_ir_type); - // {y,x} is an index to input_matrix_shape [height,width]. We need to - // convert that to an index to input_shape (the shape of the operand of - // "reduce"). 
This conversion is composed of a transposition from - // input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_matrix_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - - const Shape input_matrix_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {height, width}); - const IrArray::Index input_matrix_index({y, x}, input_matrix_shape, - &b_); - const IrArray::Index input_index = - input_matrix_index - .SourceIndexOfReshape(input_matrix_shape, - normalized_input_shape, &b_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, &b_); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i * kTileWidth + x_offset], - input_address}, - partial_reduction_result_addresses[i * kTileWidth + x_offset])); - TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens)); - } - } - return Status::OK(); - }; - - // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location - // that's immediately beyond the tile. - llvm::Value* y_end = - NSWAdd(index_typed_constant(kTileHeight), - NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); - // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location - // that's immediately beyond the tile. 
- llvm::Value* x_end = - NSWAdd(index_typed_constant(kTileWidth), - NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); - llvm::Value* tile_in_y_bounds = - Or(ICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); - llvm::Value* tile_in_x_bounds = - Or(ICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); - // The tile is in y bounds if "height" is a multiple of kTileHeight or - // y_end <= height. - llvm_ir::LlvmIfData if_tile_in_y_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_); - // The tile is in x bounds if "width" is a multiple of kTileWidth or - // x_end <= width. - llvm_ir::LlvmIfData if_tile_in_x_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, - /*tile_in_x_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, - /*tile_in_x_bounds=*/false)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_); - if_tile_in_x_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, - /*tile_in_x_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, - /*tile_in_x_bounds=*/false)); - - // After the nested if-then-else statement on tile_in_y_bounds and - // tile_in_x_bounds, emit atomic operations to accumulate the partial - // reduction result to the output element. 
- llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_); - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - for (int i = 0; i != num_reduces; ++i) { - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index( - x, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i * kTileWidth + x_offset])); - } - } - return Status::OK(); - }; - - // Emit a parallel loop that iterate through all input tiles. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -static std::pair ComputeKernelMappingSchemeForReduction( - int64 depth, int64 width, int64 kWarpSize) { - constexpr int64 kTargetNumElementsPerThread = 64; - int64 x_tile_size = kTargetNumElementsPerThread; - int64 z_tile_size = 1; - - // Only tile along the x dimension with tile size kTargetNumElementsPerThread - // if doing so doesn't require a slow version of loop with bound check on each - // dimension. A more sophisticated heuristics is to enable tile along the - // x dimension with tile size kTargetNumElementsPerThread when either width is - // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big - // enough so that only a small fraction of the threads execute the slow - // version of loop with bound check. 
- if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { - x_tile_size = 8; - z_tile_size = 8; - while (depth % z_tile_size != 0) { - z_tile_size -= 1; - } - } - - return std::pair(x_tile_size, z_tile_size); -} - -Status IrEmitterUnnested::EmitRowReduction( - KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // A naive algorithm is: - // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. - // 2. Partially reduces each tile to a scalar using one thread. - // 3. Accumulates that scalar to the output vector using atomic operations. - // - // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < depth * height * width_in_tiles; - // linear_index += blockDim.x * gridDim.x) { - // int x_in_tiles = linear_index % width_in_tiles; - // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); - // float partial_result = 0; - // for (element_id_in_tile : range(x_tile_size)) { - // int x = x_in_tiles * x_tile_size + element_id_in_tile; - // if (x < width) - // partial_result = reducer(partial_result, input[z][y][x]); - // } - // AtomicReducer(&output[y], partial_result); - // } - // - // Four optimizations are performed. - // - // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 - // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead - // of making each tile consecutive, we let make tile 0 column - // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures - // that threads in a warp access consecutive memory in one iteration (i.e. - // coalesced). 
In the above example, the warp that contains thread 0-31 - // accesses column 0-31 in the first iteration, and 32-63 in the second - // iteration, and so on. - // - // 2. Partially accumulate partial reduced results computed by threads in the - // same warp using shfl_down. Using shfl_down is faster than directly using - // atomic operations because shfl_down transfers the data between threads - // using shared memory and threads in the same warp run in lock step (thus no - // extra synchronization needed). See - // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ - // for details. The downside is, to produce correct results when using - // shfl_down, we need to guarantee threads in the same warp work on input - // elements with the same y, so the number of tiles in each row must be a - // multiple of 32. - // - // 3. Specialize the case that the entire tile is in bounds. When that is - // true, we don't need to emit "if(x 0; shuffle_distance /= 2) - // partial_result = Reducer( - // partial_result, - // __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance)); - // if (lane_id == 0) - // AtomicReducer(&output[y], partial_result); - // } - // - - int64 x_tile_size; - int64 z_tile_size; - std::tie(x_tile_size, z_tile_size) = - ComputeKernelMappingSchemeForReduction(depth, width, kWarpSize); - - // Round the width in tiles up to the nearest multiple of kWarpSize, so that - // the use of shfl_down is valid. 
- const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), - {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - llvm::Type* index_ty = - GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - auto loop_body_emitter = [=](const IrArray::Index& tile_index) { - const int num_reduces = reducers.size(); - llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( - input_shape.element_type(), ir_emitter_context_->llvm_module()); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - - llvm::Value* z_tile = tile_index[0]; - llvm::Value* y = tile_index[1]; - llvm::Value* x_tile = tile_index[2]; - - x_tile = ZExtOrTrunc(x_tile, index_ty); - - llvm::Value* warp_id = - UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); - llvm::Value* lane_id = - URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); - - // The x-location of the last element in this z-x-tile. 
- // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = NSWAdd( - lane_id, - NSWMul(index_typed_constant(kWarpSize), - NSWAdd(index_typed_constant(x_tile_size - 1), - NSWMul(warp_id, index_typed_constant(x_tile_size))))); - - KernelSupportLibrary ksl( - &b_, - /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, - /*prevent_vectorization=*/false); - - // Emit a for-loop that partially reduces the elements in the given - // z-x-tile. - auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, - int64 x_tile_loop_bound) -> Status { - auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = - NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); - TF_RETURN_IF_ERROR(ksl.For( - "x_tile", - /*start=*/index_typed_constant(0), - /*end=*/index_typed_constant(x_tile_loop_bound), - /*step=*/1, [&](llvm::Value* x_indvar) -> Status { - // x = lane_id + - // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = NSWAdd( - lane_id, - NSWMul(index_typed_constant(kWarpSize), - NSWAdd(x_indvar, - NSWMul(warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); - - // Unless we know the x-tile is entirely in bounds, we have to - // emit a x-in-bounds check before reading from the input. - if (!x_tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = - llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); - // Points b_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &b_); - } - - // Emit code that reads the input element and accumulates it - // to the partial reduction result. - llvm::Value* input_address = Alloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape - // [depth,height,width]. We need to convert that to an index - // to input_shape (the shape of the operand of "reduce"). 
- // This conversion is composed of a transposition from - // input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = ShapeUtil:: - MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = - LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - input_shape.element_type(), {depth, height, width}); - const IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &b_); - const IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &b_) - .SourceIndexOfTranspose( - normalized_input_shape, input_shape, - transpose_dimension_mapping, &b_); - - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens); - } - })); - return Status::OK(); - }; - - return ksl.For("z_tile", - /*start=*/index_typed_constant(0), - /*end=*/index_typed_constant(z_tile_size), - /*step=*/1, emit_z_tile_element_loop); - }; - - llvm::Value* tile_in_bounds = - Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - ICmpULT(last_x, index_typed_constant(width))); - - TF_RETURN_IF_ERROR( - ksl.If(tile_in_bounds, - /*true_block_generator=*/ - [&]() -> Status { - return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, - x_tile_size); - }, - /*false_block_generator=*/ - [&]() -> Status { - return emit_z_x_tile_element_loop( - /*x_tile_in_bounds=*/false, - 
CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); - })); - - // After accumulating the elements of the z_x_tile, emit calls to - // shfl_down that accumulate the partial reduction results of all - // threads in a warp. - int bit_width = llvm_ir::GetSizeInBits(element_ir_type); - // bitcast cannot be applied to aggregate types (even packed ones), so we - // instead bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? b_.getIntNTy(bit_width) - : element_ir_type; - for (int shuffle_distance = 16; shuffle_distance >= 1; - shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = - Alloca(element_ir_type, nullptr, "result_from_other_lane"); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = - Load(BitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); - CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) - << "Requires block size a multiple of the warp size, otherwise we " - "will read undefined elements."; - Store(EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], result_from_other_lane}, - partial_reduction_result_addresses[i])); - } - } - - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - - // Emit an atomic operation that accumulates the partial reduction result of - // lane 0 (which holds the partially accumulated result for its warp) to the - // output element. 
- llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index(y, - ShapeUtil::GetSubshape( - output->shape(), reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - // We don't need to emit atomic operations if there is only one tile of - // results. 'depth' is the z dimension, 'width' is the x dimension. - if (z_tile_size >= depth && x_tile_size >= width) { - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {output_address, partial_reduction_result_addresses[i]}, - output_address)); - } else { - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i])); - } - } - return Status::OK(); - }; - - // Emit a parallel loop that iterates through every input tiles. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -// Figures out whether `reduce` is a row or column reduction, and which -// dimensions to reduce, and calls either `EmitRowReduction` or -// `EmitColumnReduction` as appropriate. -// Prerequisite: all the dimensions to keep are contiguous in the input layout -// and, if `reduce` is fused, the fused subgraph is pure -// elementwise. 
-Status IrEmitterUnnested::EmitReductionToVector( - KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span dimensions_to_reduce, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // This emission requires "reduce" to have an input layout. It is either set - // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for - // a fused kReduce). - CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " - "doesn't set the input layout of " - << reduce->ToString(); - - // Specialize multi-dimensional-array-to-vector reduction. - std::vector input_dims_to_keep; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(), - input_dim) == dimensions_to_reduce.end()) { - input_dims_to_keep.push_back(input_dim); - } - } - - // Sort the dimensions to keep from minor to major, to facilitate checking - // whether another dimension is major or minor of them. - std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(), - [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - dim_a) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - dim_b); - }); - // Now, if output rank is at least 1, `input_dims_to_keep.front()` is - // minormost and `input_dims_to_keep.back()` is majormost. - - // If the dimensions to keep are minormost, emit a column reduction. As all - // the dimensions to keep are contiguous, by prerequisite of - // `EmitReductionToVector`, we only need to check whether the minormost - // dimension of the input is to keep. 
- if (ShapeUtil::IsEffectiveScalar(reduce->shape())) { - return EmitReductionToScalar(kernel_thunk, reduce, input_shape, input_gens, - init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } else if (input_dims_to_keep.front() == - LayoutUtil::Minor(input_shape.layout(), 0)) { - // Column reduction. Treat the result of "input" as a matrix whose width - // is the most minor dimension and height the product of other dimensions, - // and treat "reduce" as a column reduction of the input matrix. - const int64 width = ShapeUtil::ElementsIn(reduce->shape()); - // "width" can be zero, so don't do - // height = ShapeUtil::ElementsIn(input_shape) / width; - int64 height = 1; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(), - input_dim)) { - height *= input_shape.dimensions(input_dim); - } - } - return EmitColumnReduction(kernel_thunk, height, width, reduce, input_shape, - input_gens, init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } else { - // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a - // 3D tensor. The size of dimension 1 (the height) is the size of the - // dimension to keep, the size of dimension 0 (the depth) is the product - // of dimensions that are more major than the dimension to keep, and the - // size of dimension 2 (the width) is the product of more minor - // dimensions. 
- int64 depth = 1; - int64 width = 1; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dim) > - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dims_to_keep.back())) { - depth *= input_shape.dimensions(input_dim); - } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dim) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dims_to_keep.front())) { - width *= input_shape.dimensions(input_dim); - } - } - const int64 height = ShapeUtil::ElementsIn(reduce->shape()); - return EmitRowReduction(kernel_thunk, depth, height, width, reduce, - input_shape, input_gens, init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } -} - Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support multi-output reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { + // TODO(b/118332391): Support multi-output reduce. + if (!reduce->shape().IsArray()) { return Unimplemented("Multi-output reduce is not supported on GPU"); } - auto input = reduce->operand(0); - auto init_value = reduce->operand(1); - absl::Span dimensions_to_reduce(reduce->dimensions()); - HloComputation* reducer = reduce->to_apply(); - // HandleReduce specializes reduction from a multi-dimensional array to a 1D - // array. The specialized version requires an initializer thunk that - // initializes the output array to the initial value of the reduce. 
if (IsReductionToVector(*reduce)) { - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(reduce)); - std::vector> thunks; - thunks.push_back(std::move(initializer_thunk)); - std::unique_ptr kernel_thunk = - BuildKernelThunk(reduce, /*implements_whole_instruction=*/false); - - TF_CHECK_OK(EmitReductionToVector( - kernel_thunk.get(), reduce, input->shape(), - {[&](const IrArray::Index& index) { - return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); - }}, - {[&](const IrArray::Index& index) { - return GetIrArray(*init_value, *reduce) - .EmitReadArrayElement(index, &b_); - }}, - dimensions_to_reduce, {reducer}, {{}}, {})); - - thunks.push_back(std::move(kernel_thunk)); - - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), reduce); - AddThunkToThunkSequence(std::move(sequential_thunk)); - return Status::OK(); + return EmitReductionToVector(reduce); } return IrEmitter::HandleReduce(reduce); @@ -1755,8 +735,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto* source = select_and_scatter->operand(1); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = operand->shape().element_type(); - const int64 rank = ShapeUtil::Rank(operand->shape()); - CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + const int64 rank = operand->shape().rank(); + CHECK_EQ(rank, source->shape().rank()); CHECK_EQ(rank, window.dimensions_size()); TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, @@ -1820,7 +800,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Create the inner loop to iterate over the window. 
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, index_type); - std::vector window_size; + DimensionVector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); CHECK_GT(dim.size(), 0); @@ -2014,18 +994,18 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { BuildKernelThunk(scatter, /*implements_whole_instruction=*/thunks.empty())); - TF_RETURN_IF_ERROR( - EmitScatter(thunks.back().get(), scatter, - /*scatter_indices_gen=*/ - [=](const IrArray::Index& index) { - return GetIrArray(*scatter_indices, *scatter) - .EmitReadArrayElement(index, &b_, "scatter_index"); - }, - /*updates_gen=*/ - [=](const IrArray::Index& index) { - return GetIrArray(*updates, *scatter) - .EmitReadArrayElement(index, &b_, "update"); - })); + TF_RETURN_IF_ERROR(EmitScatter( + thunks.back().get(), scatter, + /*scatter_indices_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*scatter_indices, *scatter) + .EmitReadArrayElement(index, &b_, "scatter_index"); + }, + /*updates_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*updates, *scatter) + .EmitReadArrayElement(index, &b_, "update"); + })); // Elide the sequential thunk if there's no copy. if (thunks.size() == 1) { @@ -2072,7 +1052,7 @@ Status IrEmitterUnnested::EmitScatter( int64 raw_window_multidim_idx = 0; std::vector input_window_multidim; std::vector input_window_bounds; - for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) { + for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) { if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { input_window_bounds.push_back(1); // Trivial dimension. 
input_window_multidim.push_back(index.GetConstantWithIndexType(0)); @@ -2084,12 +1064,11 @@ Status IrEmitterUnnested::EmitScatter( ++raw_window_multidim_idx; } } - DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape())); + DCHECK_EQ(input_window_multidim.size(), operand->shape().rank()); // Insert a 1 dimension at the end if index_vector_dim requests one. Shape scatter_indices_shape = scatter_indices->shape(); - if (dim_numbers.index_vector_dim() == - ShapeUtil::Rank(scatter_indices_shape)) { + if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) { scatter_indices_shape.add_dimensions(1); scatter_indices_shape.mutable_layout()->add_minor_to_major( dim_numbers.index_vector_dim()); @@ -2174,17 +1153,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; Shape keys_shape = sort->operand(0)->shape(); int64 dimension_to_sort = sort->dimensions(0); - // In case there is a 'values' parameter that is a iota, we take note and use - // it later to ensure a stable sort. Otherwise, we don't guarantee a stable - // sort. - int64 iota_values_parameter_index = -1; for (int64 i = 0; i < sort->operand_count(); ++i) { - if (i > 0 && sort->operand(i)->opcode() == HloOpcode::kIota && - ShapeUtil::ElementIsIntegral(sort->operand(i)->shape()) && - Cast(sort->operand(i))->iota_dimension() == - dimension_to_sort) { - iota_values_parameter_index = i; - } ShapeIndex shape_index = sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); // We assume that the layout of all involved operands and outputs is the @@ -2297,25 +1266,23 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); - IrArray keys_array; std::vector values_arrays; - values_arrays.reserve(sort->operand_count() - 1); + values_arrays.reserve(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { ShapeIndex shape_index = sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); - if (i == 0) { - keys_array = GetIrArray(*sort, *sort, shape_index); - } else { - values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); - } + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); } return llvm_ir::EmitSortInPlace( - dimension_to_sort, keys_array, values_arrays, - iota_values_parameter_index, IrName(sort), xor_masks, &b_, + dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_, launch_dimensions, xor_masks.size() > 1 ? num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, - kTileSize); + kTileSize, + [&](absl::Span operands, llvm::Value* output) { + return EmitCallToNestedComputation(*sort->to_apply(), operands, + output); + }); }; std::vector xor_masks; for (int64 stage = 0; stage < num_stages; ++stage) { @@ -2352,11 +1319,11 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { return IrEmitter::HandleTupleSelect(tuple_select); } -Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on GPU. 
return Unimplemented( - "CrossReplicaSum with >1 replica is not implemented on GPU."); + "AllReduce with >1 replica is not implemented on GPU."); } // CRS with one operand and one replica is simply the identity function. @@ -2367,8 +1334,8 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { // HloModuleConfig::num_replicas changes between when the module is compiled // and when it's run. if (crs->operand_count() == 1) { - CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) - << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + CHECK(crs->operand(0)->shape().IsArray()) + << "Operands to all-reduce must be arrays: " << crs->ToString(); AddThunkToThunkSequence(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), @@ -2566,10 +1533,10 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( return !allocation->is_constant(); }); - std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), - [](const BufferAllocation* a, const BufferAllocation* b) { - return a->index() < b->index(); - }); + absl::c_sort(non_constant_buffers, + [](const BufferAllocation* a, const BufferAllocation* b) { + return a->index() < b->index(); + }); llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); @@ -2814,6 +1781,29 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( /*output_shape=*/inst->shape(), inst); } +std::unique_ptr IrEmitterUnnested::BuildTriangularSolveThunk( + const HloInstruction* inst) { + const HloInstruction* a = inst->operand(0); + const HloInstruction* b = inst->operand(1); + int64 m = b->shape().dimensions(b->shape().rank() - 2); + int64 n = b->shape().dimensions(b->shape().rank() - 1); + int64 batch_size = std::accumulate( + b->shape().dimensions().begin(), b->shape().dimensions().end() - 2, + int64{1}, [](int64 a, int64 b) { return a * b; }); + int64 elem_size = + 
ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type()); + int64 a_batch_stride = inst->triangular_solve_options().left_side() + ? m * m * elem_size + : n * n * elem_size; + int64 b_batch_stride = m * n * elem_size; + return absl::make_unique( + inst->triangular_solve_options(), + /*a_input_buffer=*/GetAllocationSlice(*a), + /*b_input_buffer=*/GetAllocationSlice(*inst), + inst->shape().element_type(), batch_size, m, n, a_batch_stride, + b_batch_stride, inst); +} + StatusOr> IrEmitterUnnested::BuildInitializerThunk( HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); @@ -3121,11 +2111,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); - TF_RETURN_IF_ERROR( - KernelSupportLibrary(&b_).If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); - return Status::OK(); - })); + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + }); // For multioutput fusion, we need to emit each operand and the root. TF_RETURN_IF_ERROR( @@ -3139,12 +2127,36 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( return Status::OK(); } +namespace { + +// Returns true if the fusion contains any instruction that is likely +// translated to complex LLVM IR, such as loops, and prevent vectorization. 
+bool MayPreventVectorization(const HloInstruction& fusion_hlo) { + CHECK_EQ(fusion_hlo.opcode(), HloOpcode::kFusion); + return absl::c_any_of( + fusion_hlo.fused_instructions_computation()->instructions(), + [&](const HloInstruction* instr) { + switch (instr->opcode()) { + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSort: + case HloOpcode::kDot: + return true; + default: + return false; + } + }); +} + +} // namespace + Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { int unroll_factor = 1; // Unfused elementwise operations are usually memory bound, unroll them. - if (hlo.IsElementwise() || hlo.opcode() == HloOpcode::kFusion) { + if (hlo.IsElementwise() || + (hlo.opcode() == HloOpcode::kFusion && !MayPreventVectorization(hlo))) { unroll_factor = ComputeMaxUnrollFactor(&hlo); } @@ -3167,7 +2179,6 @@ std::vector IrEmitterUnnested::ConstructIrArrayForInputs( return param_arrays; } - int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, @@ -3195,54 +2206,90 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -void EmitFullTile(const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Type* index_ty, - const std::function& emit_elem_function) { +std::tuple GetStartOffsetAndStepForX( + int64 tile_size_x, int64 num_threads_x, + const KernelMappingScheme* mapping_scheme, llvm::IRBuilder<>* builder, + llvm::Value* x, llvm::Type* index_ty) { + llvm::Value* start_offset_x; + int64 step_x; + if (mapping_scheme->DilatedX()) { + start_offset_x = x; + step_x = num_threads_x; + } else { + start_offset_x = builder->CreateMul( + x, llvm::ConstantInt::get(index_ty, tile_size_x / num_threads_x)); + step_x = 1; + } + return 
std::make_tuple(start_offset_x, step_x); +} + +void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + const string& loop_name, KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Type* index_ty, + const EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); - for (int64 i = 0; i < tile_size_y; i += num_threads_y) { - IrArray::Index source_idx_y = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), - KernelMappingScheme::DimY, builder); - llvm::Value* y_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - emit_elem_function(source_idx, y_loc, x_loc); - } - } -} -void EmitPartialTile( - const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - llvm::Type* index_ty, - const std::function& emit_elem_function) { + llvm::Value* start_offset_x; + int64 step_x; + std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( + tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); + ksl->For(loop_name + "_y", 
/*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y), + /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + IrArray::Index source_idx_y = source_idx.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_y_x = source_idx_y.AddOffsetToDim( + llvm::ConstantInt::get(index_ty, j * step_x), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = builder->CreateAdd( + llvm::ConstantInt::get(index_ty, j * step_x), + start_offset_x); + emit_elem_function(source_idx_y_x, y_loc, x_loc, j); + } + }); +} + +void EmitPartialElementalTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + const string& loop_name, + KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, + llvm::Value* tile_width, llvm::Type* index_ty, + const EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - - ksl->IfReturnVoid( - "x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { + llvm::Value* start_offset_x; + int64 step_x; + std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( + tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) + 
.AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_x = + source_idx.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j * step_x), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = builder->CreateAdd( + llvm::ConstantInt::get(index_ty, j * step_x), start_offset_x); + + ksl->If( + loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), + [&] { // tile_height_bound = // ceil(tile_height / num_threads_y) * num_threads_y llvm::Value* ceiling_of_ratio = builder->CreateUDiv( @@ -3252,20 +2299,19 @@ void EmitPartialTile( llvm::Value* tile_height_bound = builder->CreateMul( ceiling_of_ratio, llvm::ConstantInt::get(index_ty, num_threads_y)); - ksl->ForReturnVoid( + ksl->For( loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), /*end=*/tile_height_bound, /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* y_indvar) { llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( - "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), - [&] { - emit_elem_function( - source_idx.AddOffsetToDim( - y_indvar, KernelMappingScheme::DimY, builder), - y_loc, x_loc); - }); + ksl->If(loop_name + "_y_in_tile", + builder->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function( + source_idx_x.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder), + y_loc, x_loc, j); + }); }); }); } @@ -3284,27 +2330,26 @@ void EmitTiledElementalCodeWithBoundsCheck( const IrArray::Index& tile_origin_index, const string& loop_name, KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - const std::function& emit_elem_function) { + const EmitElementFunction& emit_elem_function) { int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = 
tile_width->getType(); - ksl->IfReturnVoid( - "full_tile", + ksl->If( + loop_name + "_full_tile", builder->CreateAnd( builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), tile_width), builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), tile_height)), [&] { - EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, - emit_elem_function); + EmitFullElementalTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, index_ty, emit_elem_function); }, [&] { - EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, - builder, y, x, tile_height, tile_width, index_ty, - emit_elem_function); + EmitPartialElementalTile(mapping_scheme, tile_origin_index, loop_name, + ksl, builder, y, x, tile_height, tile_width, + index_ty, emit_elem_function); }); } } // namespace @@ -3321,7 +2366,7 @@ void EmitTiledElementalCodeWithBoundsCheck( void IrEmitterUnnested::EmitTileElementForCopy( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm_ir::TiledParameterInfo* tiled_param_info = kernel_info->GetTiledParameterInfo(); // TODO(jlebar): Add AA metadata to this load. @@ -3351,7 +2396,7 @@ void IrEmitterUnnested::EmitTileElementForCopy( void IrEmitterUnnested::EmitTileElementForFusion( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm_ir::TiledParameterInfo* tiled_param_info = kernel_info->GetTiledParameterInfo(); std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); @@ -3382,10 +2427,443 @@ void IrEmitterUnnested::EmitTileElementForFusion( } } -// Emits a block of tiles, given a function object to emit one tile. +// Information to support the code generation for a tiled reduction kernel. 
+using AddressVector = InlinedVector; +class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { + public: + explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme, + bool is_row_reduction) + : KernelCodegenInfo(mapping_scheme), + current_output_linear_index_address_(nullptr), + current_output_inbound_address_(nullptr), + is_row_reduction_(is_row_reduction) {} + + void SetCurrentOutputLinearIndexAddress(llvm::AllocaInst* a) { + current_output_linear_index_address_ = a; + } + // Returns the address of the memory that stores the linear index of the + // current output. Since we are processing reduction to contiguous physical + // dimensions, this linear index is the linear index of the 1D output array. + llvm::AllocaInst* GetCurrentOutputLinearIndexAddress() const { + return current_output_linear_index_address_; + } + + void SetCurrentOutputInboundAddress(llvm::AllocaInst* a) { + current_output_inbound_address_ = a; + } + + llvm::AllocaInst* GetCurrentOutputInboundAddress() const { + return current_output_inbound_address_; + } + + AddressVector* GetMutablePartialResultAddresses() { + return &partial_result_addresses_; + } + absl::Span GetPartialResultAddresses() const { + return partial_result_addresses_; + } + + AddressVector* GetMutableReductionInputAddresses() { + return &reduction_input_addresses_; + } + absl::Span GetReductionInputAddresses() const { + return reduction_input_addresses_; + } + + InlinedVector* GetMutableReducers() { return &reducers_; } + const InlinedVector& GetReducers() const { + return reducers_; + } + int GetNumberOfReduces() const { return reducers_.size(); } + + InlinedVector* GetMutableReductionOutputShapeIndices() { + return &reduction_output_shape_indices_; + } + absl::Span GetReductionOutputShapeIndices() const { + return reduction_output_shape_indices_; + } + + bool IsRowReduction() const { return is_row_reduction_; } + + // Return the dimension that is being reduced between DimX and DimY. 
+ int GetReducedDimensionEnum() const { + return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimX + : llvm_ir::KernelMappingScheme::DimY; + } + + // Return the dimension that is being ketp between DimX and DimY. + int GetKeptDimensionEnum() const { + return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimY + : llvm_ir::KernelMappingScheme::DimX; + } + + int GetNumberOfPartialResults() const { + if (IsRowReduction()) { + return 1; + } + int64 num_thread = mapping_scheme_->GetNumberOfThreadsForDimensionX(); + int64 tile_size = mapping_scheme_->GetTileSizeForDimensionX(); + CHECK_EQ(tile_size % num_thread, 0); + return tile_size / num_thread; + } + + int GetPartialResultIndex(int64 x_iter_num) const { + if (IsRowReduction()) { + return 0; + } + return x_iter_num; + } + + private: + AddressVector partial_result_addresses_; + AddressVector reduction_input_addresses_; + InlinedVector reducers_; + InlinedVector reduction_output_shape_indices_; + llvm::AllocaInst* current_output_linear_index_address_; + llvm::AllocaInst* current_output_inbound_address_; + bool is_row_reduction_; +}; + +namespace { +// Returns a group of instructions that generate the output for the kernel +// containing the given HLO instruction. The result may be an unnested kReduce +// HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple +// for a multiple output fusion. +absl::Span GetOutputInstructions( + HloInstruction* const* reduce_or_tuple_pointer) { + HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode(); + CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple); + return opcode == HloOpcode::kTuple + ? 
(*reduce_or_tuple_pointer)->operands() + : absl::Span(reduce_or_tuple_pointer, 1); +} + +const HloInstruction* GetFirstReduceInstruction( + absl::Span instructions) { + auto first_reduce_iter = + absl::c_find_if(instructions, [](const HloInstruction* inst) { + return inst->opcode() == HloOpcode::kReduce; + }); + CHECK_NE(first_reduce_iter, instructions.end()); + return *first_reduce_iter; +} + +}; // namespace + +void IrEmitterUnnested::EmitPrologueForOneReduction( + HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx, + KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter, + ShapeIndex output_shape_index) { + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + + InlinedVector* reducers = + reduction_info->GetMutableReducers(); + CHECK(IsReductionToVector(*reduce_inst)); + reducers->push_back(reduce_inst->to_apply()); + + InlinedVector* reduction_output_shape_indices = + reduction_info->GetMutableReductionOutputShapeIndices(); + reduction_output_shape_indices->push_back(std::move(output_shape_index)); + + AddressVector* reduction_input_addresses = + reduction_info->GetMutableReductionInputAddresses(); + llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( + reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module()); + llvm::AllocaInst* reduction_input_address = Alloca(element_type); + reduction_input_addresses->push_back(reduction_input_address); + + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + AddressVector* partial_result_addresses = + reduction_info->GetMutablePartialResultAddresses(); + llvm::AllocaInst* partial_result_address = + Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results), + "partial_reduction_result." + llvm::Twine(reduce_idx)); + partial_result_addresses->push_back(partial_result_address); + + // Initialize the partial result with the initial value of the reduction. 
+ llvm::Value* init_ir_value; + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + HloInstruction* init_value_operand = reduce_inst->mutable_operand(1); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + elemental_emitter); + + TF_CHECK_OK(init_value_operand->Accept(&fused_emitter)); + init_ir_value = + fused_emitter + .GetGenerator(init_value_operand)(IrArray::Index(b_.getInt32Ty())) + .ValueOrDie(); + } else { + const HloInstruction* init_value = unnested_hlo->operand(1); + init_ir_value = + GetIrArray(*init_value, *unnested_hlo) + .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_); + } + + for (int i = 0; i < num_partial_results; ++i) { + Store(init_ir_value, InBoundsGEP(partial_result_address, {b_.getInt32(i)})); + } +} + +void IrEmitterUnnested::EmitPrologueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString(); + // Find the unnested kReduce or the tuple that contains a list of kReduce. + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? 
unnested_hlo->fused_expression_root() + : unnested_hlo; + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); + const HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + continue; + } + HloInstruction* reduce_inst = output_instructions[i]; + if (first_reduce == nullptr) { + first_reduce = reduce_inst; + } else { + CHECK(first_reduce->dimensions() == reduce_inst->dimensions()); + } + ShapeIndex output_shape_index; + if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + + EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info, + &elemental_emitter, + std::move(output_shape_index)); + } + + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + + // Allocate stack storage to store the linear indices for the current output, + // and record the address of the storage. 
+ reduction_info->SetCurrentOutputLinearIndexAddress( + Alloca(reduction_info->GetIndexType(), + /*ArraySize=*/b_.getInt32(num_partial_results), + "current_output_linear_index_address")); + + if (!reduction_info->IsRowReduction()) { + llvm::Type* bool_ty = b_.getInt1Ty(); + llvm::AllocaInst* output_inbound_addr = Alloca(bool_ty); + Store(llvm::ConstantInt::get(bool_ty, 0), output_inbound_addr); + reduction_info->SetCurrentOutputInboundAddress(output_inbound_addr); + } +} + +void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces( + absl::Span reducers, + absl::Span partial_result_addresses) { + for (int distance = 16; distance >= 1; distance /= 2) { + for (int i = 0; i != reducers.size(); ++i) { + llvm::Type* element_type = + partial_result_addresses[i]->getType()->getElementType(); + int bit_width = llvm_ir::GetSizeInBits(element_type); + llvm::Value* result_from_other_lane = Alloca( + element_type, nullptr, "result_from_other_lane" + llvm::Twine(i)); + // Bitcast cannot be applied to aggregate types (even packed ones), so + // we bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffled_value_type = + element_type->isStructTy() ? 
b_.getIntNTy(bit_width) : element_type; + auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { + return BitCast(ptr, shuffled_value_type->getPointerTo()); + }; + llvm::Value* partial_result = + Load(convert_pointer_for_shuffle(partial_result_addresses[i]), + "partial_reduction_result"); + Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), + convert_pointer_for_shuffle(result_from_other_lane)); + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], {partial_result_addresses[i], result_from_other_lane}, + partial_result_addresses[i])); + } + } +} + +void IrEmitterUnnested::EmitEpilogueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + int num_reduces = reduction_info->GetNumberOfReduces(); + absl::Span partial_result_addresses = + reduction_info->GetPartialResultAddresses(); + const InlinedVector& reducers = + reduction_info->GetReducers(); + absl::Span reduction_output_shape_indices = + reduction_info->GetReductionOutputShapeIndices(); + + if (reduction_info->IsRowReduction()) { + EmitFullWarpShuffleDownLoopForAllReduces(reducers, + partial_result_addresses); + llvm::Value* lane_id = reduction_info->GetLaneId(); + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ICmpEQ(lane_id, llvm::ConstantInt::get(lane_id->getType(), 0)), + "lane_id_is_zero", &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); + } else { + llvm::Value* output_inbound_addr = + reduction_info->GetCurrentOutputInboundAddress(); + llvm::Value* output_inbound = Load(output_inbound_addr); + llvm_ir::LlvmIfData if_output_inbound_data = llvm_ir::EmitIfThenElse( + ICmpEQ(output_inbound, + llvm::ConstantInt::get(output_inbound->getType(), 1)), + "output_inbound", &b_); + llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); + } + + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + + 
// Emit an atomic operation that accumulates the partial reduction to the + // output element. For row reduction, this is only for lane 0 due to the + // if-statement emitted above. + for (int i = 0; i != num_reduces; ++i) { + for (int j = 0; j < num_partial_results; ++j) { + IrArray::Index element_index( + /*linear=*/Load( + InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(j)}), + "output_linear_addr"), + ShapeUtil::GetSubshape(unnested_hlo->shape(), + reduction_output_shape_indices[i]), + &b_); + llvm::Value* output_address = + GetIrArray(*unnested_hlo, *unnested_hlo, + reduction_output_shape_indices[i]) + .EmitArrayElementAddress(element_index, &b_, + "output_element_address"); + // Do not emit atomic operations if each element in the reduction result + // is computed by one block, that is the dimension being reduced has only + // one block. + const llvm_ir::KernelMappingScheme* mapping_scheme = + reduction_info->GetKernelMappingScheme(); + if (mapping_scheme->GetTileBlockSizeForDimension( + llvm_ir::KernelMappingScheme::DimZ) == 1 && + mapping_scheme->GetTileBlockSizeForDimension( + reduction_info->GetReducedDimensionEnum()) == 1) { + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], + {output_address, + InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)})}, + output_address)); + } else { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)}))); + } + } + } +} + +void IrEmitterUnnested::EmitTileElementForReduction( + HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num) { + VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? 
unnested_hlo->fused_expression_root() + : unnested_hlo; + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + tiled_param_info->set_y(y_loc); + tiled_param_info->set_x(x_loc); + + // Record the linear address for the current reduction. + const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + int partial_result_index = reduction_info->IsRowReduction() ? 0 : x_iter_num; + + Store(index[reduction_info->GetKeptDimensionEnum()], + InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(partial_result_index)})); + if (!reduction_info->IsRowReduction()) { + llvm::Type* bool_ty = b_.getInt1Ty(); + llvm::AllocaInst* output_inbound_addr = + reduction_info->GetCurrentOutputInboundAddress(); + Store(llvm::ConstantInt::get(bool_ty, 1), output_inbound_addr); + } + + InlinedVector input_gens; + std::vector> + extra_output_gens; + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + &elem_emitter); + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + // Construct the ElementGenerator for each reduction and extra output in the + // the group of output instructions. 
+ if (unnested_hlo->opcode() == HloOpcode::kFusion) { + fused_emitter.SetTiledParameterInfo(tiled_param_info); + TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); + + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + if (inst->opcode() == HloOpcode::kReduce) { + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + } else { + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } + } + } else { + input_gens.push_back([&](const IrArray::Index& index) { + return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo) + .EmitReadArrayElement(index, &b_); + }); + } + + IrArray::Index input_index = + reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( + index, + GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + if (num_partial_results > 1) { + // Clear the linear index field of the IrArray::Index to enable the use of + // GetElementPointer with array types. This enables the vectorization of + // the computation for different partial results. + input_index.ClearLinearIndex(); + } + absl::Span partial_reduction_result_addresses = + reduction_info->GetPartialResultAddresses(); + absl::Span reduction_input_addresses = + reduction_info->GetReductionInputAddresses(); + const InlinedVector& reducers = + reduction_info->GetReducers(); + + // Emit code to generate the input and perform the reduction computation for + // each reduction instruction. 
+ for (int i = 0; i != reducers.size(); ++i) { + llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); + Store(input_ir_value, reduction_input_addresses[i]); + llvm::Value* partial_result_address = + InBoundsGEP(partial_reduction_result_addresses[i], + {b_.getInt32(partial_result_index)}); + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], {partial_result_address, reduction_input_addresses[i]}, + partial_result_address)); + } + + // Emit code to generate the output for the non-reduction instructions in the + // fusion, if any. + TF_CHECK_OK( + EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens)); +} + +// Emits a kernel for the hlo instruction using the given tiling scheme. void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, - const KernelCodegenInfo* kernel_info, - KernelSupportLibrary& ksl, + KernelCodegenInfo* kernel_info, + KernelSupportLibrary* ksl, llvm::Type* index_ty) { KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); @@ -3418,16 +2896,14 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, llvm::Value* num_tiles_in_block = Select(ICmpEQ(last_block_for_dim, block_id_for_dim), last_block_size_for_dim, block_size_for_dim); - - ksl.ForReturnVoid( - loop_name, - /*start=*/index_typed_constant(0), - /*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - IrArray::Index tile_index = starting_tile.AddOffsetToDim( - block_dim_induction_var, dim_id, &b_); - emit_next_block_dim(tile_index); - }); + ksl->For(loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); } }; @@ -3482,7 +2958,8 @@ void IrEmitterUnnested::EmitBlock(const 
TileGenerator& emit_one_tile, // unnested_hlo: The unnested hlo instruction for which the kernel is generated. // Currently, these hlo instructions are supported: kLoop fusion, kCopy. // tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of -// other tensors with the same dimensions and need to be tiled and tranposed. +// other tensors with the same dimensions and are safe to be tranposed via +// the shared memory tranpose implementation. // mapping_scheme: The tiling scheme to use. // kernel_generator: Contains function objects for code generation, such as // element generator, block prologue and epilogue generators. @@ -3509,11 +2986,22 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( << llvm_ir::DumpToString(*param_shmem_buffers[id]); } - CHECK_EQ(mapping_scheme->GetThreadsPerTile() % kWarpSize, 0); - LaunchDimensions launch_dimensions = LaunchDimensions( - mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); - llvm::Type* index_ty = GetIndexTypeForKernel( - unnested_hlo, launch_dimensions.launch_bound(), &b_); + const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + bool is_column_reduction = + (reduction_info && !reduction_info->IsRowReduction()); + + LaunchDimensions launch_dimensions = + LaunchDimensions(mapping_scheme->GetNumberOfBlocks(), + mapping_scheme->GetThreadsPerBlock()); + + // TODO(b/110211620): Enable int32 index type for column reduction. + llvm::Type* index_ty = + is_column_reduction + ? b_.getInt64Ty() + : GetIndexTypeForKernel(unnested_hlo, + launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; @@ -3523,14 +3011,12 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // but we do it at the beginning in the hopes of reducing register pressure, // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel // *anyway*. 
- if (unnested_hlo->IsMultiOutputFusion()) { - TF_CHECK_OK(KernelSupportLibrary(&b_).If( - "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), - ConstructIrArrayForOutputs(*unnested_hlo), &b_, - module_); - return Status::OK(); - })); + if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), + ConstructIrArrayForOutputs(*unnested_hlo), &b_, + module_); + }); } // For each tiled parameter, cast its input IrArray to the corresponding @@ -3553,14 +3039,14 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( kernel_info->SetLaneId( mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x : nullptr); + kernel_info->SetIndexType(index_ty); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. auto emit_tiled_elemental_code_with_bounds_check = [&](const IrArray::Index& index, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, - const std::function& emit_elem_function) { + const EmitElementFunction& emit_elem_function) { EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, &ksl, &b_, y, x, tile_height, tile_width, emit_elem_function); @@ -3573,52 +3059,49 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( const IrArray::Index input_tile_origin( Permute({0, 2, 1}, output_tile_origin.multidim())); - const IrArray::Index input_index = - input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) - .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - // Note that tile_width and tile_height are flipped here because we are - // reading a transposed tile. 
- emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. Tile buffers are - // global variables, so LLVM can't infer much about it. - Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); - } - }); - // If shared memory transpose is needed, wait for all threads to reach this // point, lest we copy a value from tile to output before the other thread // copies it from input to tile. This is `__syncthreads` in CUDA. if (!tiled_param_ids.empty()) { + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we are + // reading a transposed tile. + emit_tiled_elemental_code_with_bounds_check( + input_tile_origin, "input", output_tile_bounds[2], + output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc, int64 /*x_iter_num*/) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = + param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + Store(input_in_logical_shape.EmitReadArrayElement( + index, &b_, "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); + } + }); + + // Wait for all threads to reach this point using `__syncthreads` in CUDA. 
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); } llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); kernel_info->SetTiledParamInfo(&tiled_param_info); - const IrArray::Index output_index = - output_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) - .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Write to output[index] by emitting code like normal, except that values // for the tiled parameters are read from the shmem buffers. emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[1], output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - kernel_generator.GetTileElementGenerator()(unnested_hlo, index, - kernel_info, y_loc, x_loc); + output_tile_origin, "output", output_tile_bounds[1], + output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num) { + kernel_generator.GetTileElementGenerator()( + unnested_hlo, index, kernel_info, y_loc, x_loc, x_iter_num); }); + // If a tile block contains multiple tiles and shared memory buffers are // used, we need to wait for all threads to finish using the shared memory // buffer for the current tile before we move on to process the next tile @@ -3634,7 +3117,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( block_prologue_generator(unnested_hlo, kernel_info); } - EmitBlock(std::move(emit_one_tile), kernel_info, ksl, index_ty); + EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty); const BlockEpilogueGenerator& block_epilogue_generator = kernel_generator.GetBlockEpilogueGenerator(); @@ -3647,7 +3130,10 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose // algorithm to improve the memory access patterns for the input parameters -// with a shape that is a 0-2-1 transpose of the output tensor shape. 
+// with a shape that is a 0-2-1 transpose of the output tensor shape. The caller +// is responsible for making sure that it is safe to apply the shared memory +// tranpose on the input parameters. +// // // For the purpose of tiling, the output tensors have a logical shape of three // components 0-2-1 while the relevant input parameters have a logical shape @@ -3680,17 +3166,19 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( element_generator = [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc) { - EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num) { + EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc, x_iter_num); }; } else { DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); - element_generator = [&](HloInstruction* hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc) { - EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc); - }; + element_generator = + [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc, + x_iter_num); + }; } KernelCodegenInfo kernel_info(&mapping_scheme); KernelCodeGenerator kernel_generator(std::move(element_generator)); @@ -3698,26 +3186,99 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( } namespace { -// Returns true to indicate it is safe to use the tile based shared memory -// transpose implementation to implement the kernel for the instruction. +// A recursive function to inspect the users of a parameter to determine +// whether it's safe for a parameter to participate in a shared-memory +// transpose. 
// -// An instruction is not safe for such an implementation if it can change the -// element order of a tensor without changing the dimension of the tensor, and -// the instruction has a corresponding elemental_ir_emitter. -bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) { - auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) { - HloOpcode opcode = instr->opcode(); - CHECK_NE(opcode, HloOpcode::kFusion); - return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather); - }; +// Consider a fusion parameter P for which we might want to use a shmem +// transpose. If we do, we use a GPU thread block to preload a tile of P with +// indices [z, y..y+31, x..x+31] to compute an output tile with the same indices +// cooperatively, where z, y, x are the indices for the normalized input/output +// tensor (see the document for FindTranspose021 for the definition of +// normalized tensor for 0-2-1 transpose). This shmem transpose implementation +// requires that the computation of the output tile only read elements within +// the preload tile. If this is not true, we can't use a shmem transpose for P. +// +// If the computation of output element [z, y, x] only requires the element of +// P with the same indices, the shmem tranpose implementation can be applied +// to P safely. This is a sufficient but not necessary condition. We check all +// the transitive users of P to see if we can find a user that may cause an +// exception to the situation. If such a user is not found, we conclude that P +// is safe for shmem transpose. +// +// This is trivially true for elementwise operations and some "data-movement" +// ops like kTuple. However, it's not true for operations that can change the +// dimensions of the inputs (e.g. pad, slice) and bitcast operation. 
+// For example: +// +// fused_computation { +// param_0 = f32[64,64]{1,0} parameter(0) +// ROOT bitcast = f32[64,64]{0,1} bitcast(param_0) +// } +// The output element at logical address [0, 63] depends on the input element +// at logical address [63, 0], which would not be within the shared-memory +// block. +// +// TODO(bixia): In order to extend this for kInput fusion, that is reduction +// with tranpose, we only need to end the use-chain checking with the input of +// a reduce operations. In this case, the above description on "output" apply +// to the result of such a use-chain, which provides the input to the reduce +// operation. +bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { + if (hlo->IsElementwise()) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); + } + + switch (hlo->opcode()) { + // Non-elementwise instructions that don't cause the shmem transpose + // to be unsafe, including the instructions that don't currently fuse. + case HloOpcode::kGetDimensionSize: + // The result of the operation doesn't rely on the content of the + // tensor. As such, there is no need to further inspect its users. + return true; + case HloOpcode::kGetTupleElement: + case HloOpcode::kMap: + case HloOpcode::kParameter: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); - if (hlo->opcode() == HloOpcode::kFusion) { - return absl::c_all_of(hlo->fused_instructions_computation()->instructions(), - is_safe_for_tile_based_transpose); + default: + return false; } +} - return is_safe_for_tile_based_transpose(hlo); +// Given a group of input parameters that are 0-2-1 tranpose of the outputs of +// a fusion kernel, returns the input parameters that are safe for the shared +// memory tranpose implementation. 
+// +// When a tile based shared memory transpose is used to implement an input with +// 0-2-1 transpose, we preload a tile of the input elements +// [z, y..y+31, x..x+31] to compute the output tile elements of the same +// indices. Preloading the input tile this way is only safe when the computation +// of the output tile elements do not need any input element outside the +// preloaded tile. We inspect all the transitive users of the input parameter +// up to the fusion root instruction to see if we can find any instruction +// that can make preloading the input tile unsafe. +std::vector FilterInputsForShmemTranspose(const HloInstruction* fusion, + std::vector input_ids) { + std::vector filtered_input_ids; + for (int64 i = 0; i < input_ids.size(); ++i) { + const HloInstruction* input = fusion->fused_parameter(input_ids[i]); + if (IsInstructionSafeForShmemTranspose(input)) { + filtered_input_ids.push_back(input_ids[i]); + } else { + VLOG(10) << "Input not safe for shmem transpose " << input->ToString() + << "\n"; + } + } + return filtered_input_ids; } + } // namespace bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { @@ -3764,8 +3325,11 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } - if (!IsInstructionSafeForTileBasedTranspose(hlo)) { - return false; + if (opcode == HloOpcode::kFusion) { + params_012 = FilterInputsForShmemTranspose(hlo, params_012); + if (params_012.empty()) { + return false; + } } // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the @@ -3814,6 +3378,350 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return true; } +namespace { +// Checks that the outputs of a fusion with reduction are consistent. 
+Status AreFusedReductionOutputsConsistent( + absl::Span output_instructions, + const HloInstruction* first_reduce) { + for (const HloInstruction* inst : output_instructions) { + if (inst->opcode() == HloOpcode::kReduce) { + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions()); + } else { + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); + } + } + return Status::OK(); +} + +// Finds the dimensions to keep for the reduction, sorts and returns the +// dimensions from minor to major. +DimensionVector GetDimensionsToKeepMinorToMajor( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector input_dims(input_shape.rank(), 0); + absl::c_iota(input_dims, 0); + DimensionVector input_dims_to_keep; + for (int input_dim : input_dims) { + auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) { + return dim_to_reduce == input_dim; + }); + if (it == dims_to_reduce.end()) { + input_dims_to_keep.push_back(input_dim); + } + } + + // Sort the dimensions to keep from minor to major. 
+ absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) { + return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); + }); + + VLOG(10) << "dims to keep minor to major" + << absl::StrJoin(input_dims_to_keep, ","); + return input_dims_to_keep; +} + +// Given the input shape and dimensions to reduce for the reduction to vector, +// returns : +// num_kept: the number of elements in the contiguous dimensions to keep. +// num_reduced_major: the number of elements in the dimensions to reduce that +// are more major than the dimensions to keep. +// num_reduced_minor: the number of elements in the dimensions to reduce that +// are more minor than the dimensions to kept. +std::tuple GetReductionToVectorDimensions( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector input_dims_to_keep_minor_to_major = + GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce); + CHECK(LayoutUtil::AreDimensionsConsecutive( + input_shape.layout(), input_dims_to_keep_minor_to_major)); + int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1; + if (input_dims_to_keep_minor_to_major.empty()) { + return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); + } + DimensionVector input_dims(input_shape.rank(), 0); + absl::c_iota(input_dims, 0); + absl::Span minor_to_major = + LayoutUtil::MinorToMajor(input_shape); + for (int input_dim : input_dims) { + int64 curr_dim_size = input_shape.dimensions(input_dim); + if (PositionInContainer(minor_to_major, input_dim) > + PositionInContainer(minor_to_major, + input_dims_to_keep_minor_to_major.back())) { + num_reduced_major *= curr_dim_size; + } else if (PositionInContainer(minor_to_major, input_dim) < + PositionInContainer(minor_to_major, + input_dims_to_keep_minor_to_major.front())) { + num_reduced_minor *= curr_dim_size; + } else { + num_kept *= curr_dim_size; + } + } + + return 
std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); +} + +// Returns true if all the transitive users of hlo before hitting users in +// use_chain_endings are elementwise operations. +bool AreUsersElementwise(const HloInstruction* hlo, + const ConstHloInstructionSet& use_chain_endings) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return use_chain_endings.count(user) || + (user->IsElementwise() && + AreUsersElementwise(user, use_chain_endings)); + }); +} + +// Returns the number of fusion inputs that have the same dimension as the +// given shape, and involve in only elementwise operations. +int64 NumInputsInvolveInOnlyElementwiseOps( + const HloInstruction* unnested_hlo, const Shape& op_shape, + const ConstHloInstructionSet& use_chain_endings) { + return absl::c_count_if( + unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) { + const Shape& parameter_shape = parameter->shape(); + return ShapeUtil::SameDimensions(op_shape, parameter_shape) && + AreUsersElementwise(parameter, use_chain_endings); + }); +} + +// Returns the number of fusion inputs that have more elements than the given +// shape. +int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo, + const Shape& shape) { + int64 num_elements = ShapeUtil::ElementsIn(shape); + return absl::c_count_if( + unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) { + return ShapeUtil::ElementsIn(parameter->shape()) > num_elements; + }); +} + +// The benefit of unrolling a kInput fusion that is a column reduction comes +// from the vectorization of non-reduction fusion outputs and fusion inputs. +// On the other hand, unrolling can also introduce factors that can cause +// the kernel to run slower. This routine uses a simple heuristic to estimate +// the benefit as well as the overhead of unrolling in order to decide whether +// unrolling is beneficial for the given kInput fusion. 
+bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo, + const Shape& input_shape, + int64 num_kept) { + // TODO(b/122468062): Need further investigate to see whether we can + // remove the constraint on IsPowerOfTwo. + if (!IsPowerOfTwo(static_cast(num_kept))) { + return false; + } + + if (unnested_hlo->opcode() == HloOpcode::kReduce) { + return true; + } + + CHECK_EQ(unnested_hlo->opcode(), HloOpcode::kFusion); + int64 can_be_vectorized = 0; + int64 cannot_be_vectorized = 0; + const HloInstruction* fused_root = unnested_hlo->fused_expression_root(); + ConstHloInstructionSet use_chain_endings; + if (fused_root->opcode() == HloOpcode::kReduce) { + use_chain_endings.insert(fused_root); + // Atomic.add of the reduction result can't be vectorized. + cannot_be_vectorized++; + } else { + CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple); + for (const HloInstruction* instr : fused_root->operands()) { + if (instr->opcode() == HloOpcode::kReduce) { + // Atomic.add of the reduction result can't be vectorized. + cannot_be_vectorized++; + } else { + // Write of the non-reduction result can be vectorized. + can_be_vectorized++; + } + use_chain_endings.insert(instr); + } + } + // Fusion inputs that have the same dimension as the reduce input and + // only involve in elementwise operations can be vectorized. + can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps( + unnested_hlo, input_shape, use_chain_endings); + // Fusion inputs with more elements than the reduce op input must participate + // in non-elementwise operations and we assume that they are not vectorizable + // for the purpose of estimating the benefit of unrolling. If the kernel is + // unrolled even with such an assumption, and the accesses to those inputs + // turn out to be vectorizable, the compiler will still vectorize them. 
+ cannot_be_vectorized += + NumInputsWithMoreElementsThan(unnested_hlo, input_shape); + return can_be_vectorized >= cannot_be_vectorized; +} + +} // namespace + +std::tuple +IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( + const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) { + int64 depth = 1; + int64 height = 1; + int64 width = 1; + bool is_row_reduction = true; + int64 tile_size_x = 1; + int64 tile_size_y = 1; + int64 block_size_z = 1; + int64 num_threads_x = 1; + int64 num_threads_y = 1; + const Shape& input_shape = first_reduce->operand(0)->shape(); + int64 num_input_elems = ShapeUtil::ElementsIn(input_shape); + int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape()); + int64 num_reduced_major, num_kept, num_reduced_minor; + std::tie(num_reduced_major, num_kept, num_reduced_minor) = + GetReductionToVectorDimensions(input_shape, first_reduce->dimensions()); + CHECK_EQ(num_output_elems, num_kept); + bool dilated_x = true; + + if (num_kept == 1) { + // Scalar reduction is a special row reduction with depth = height = 1. + width = num_input_elems; + tile_size_x = kWarpSize * 16; + num_threads_x = kWarpSize; + } else if (num_reduced_minor == 1) { + // Column reduction reduces inputs with dimension [height, width], where + // width is the minor dimension, to dimension [width]. + height = num_reduced_major; + width = num_kept; + is_row_reduction = false; + // Column reduction without transpose doesn't require communication among + // threads processing elements in the same tile. The current implementation + // only support the use of one hardware thread block to process one block of + // tiles in the KernelMappingScheme. We try to use one thread to compute + // the partial results for two tensor elements and to maximize the values of + // num_threads_x and tile_size_x to allow a bigger hardware thread block. 
+ int64 hw_threads_per_block_limit = + ThreadsPerBlockLimit(ir_emitter_context_->device_description()); + if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, + num_kept)) { + tile_size_x = std::min(2 * hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x / 2; + dilated_x = false; + } else { + tile_size_x = std::min(hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x; + } + int64 kNumElementsPerPartialSum = 128; + tile_size_y = kNumElementsPerPartialSum; + } else { + // Row reduction reduces inputs with dimension [depth, height, width], + // where width is the most minor dimension, to dimension [height] . + depth = num_reduced_major; + height = num_kept; + width = num_reduced_minor; + num_threads_x = kWarpSize; + if (width % (kWarpSize * 64) == 0) { + tile_size_x = kWarpSize * 64; + } else { + tile_size_x = kWarpSize * 8; + block_size_z = 8; + while (depth % block_size_z != 0) { + block_size_z -= 1; + } + } + } + DCHECK_EQ(depth * height * width, num_input_elems); + VLOG(10) << "is_row_reduction " << is_row_reduction << depth << " " << height + << " " << width; + + DimensionVector dims_in_elem{depth, height, width}; + DimensionVector req_block_sizes{block_size_z, 1, 1}; + llvm_ir::KernelMappingScheme mapping_scheme( + dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, + num_threads_x, &b_); + mapping_scheme.SetDilatedX(dilated_x); + return std::make_tuple(mapping_scheme, is_row_reduction); +} + +Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { + VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); + + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? 
unnested_hlo->fused_expression_root() + : unnested_hlo; + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + const HloInstruction* first_reduce = + GetFirstReduceInstruction(output_instructions); + + if (output_instructions.size() > 1) { + TF_RETURN_IF_ERROR( + AreFusedReductionOutputsConsistent(output_instructions, first_reduce)); + } + + // Build an initializer thunk to initialize each reduction output. + std::vector> thunks; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + continue; + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(unnested_hlo, + (output_instructions[i] == reduce_or_tuple) + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + } + + // Build a kernel thunk to compute all the outputs. + std::unique_ptr kernel_thunk = + BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); + + const Shape& input_shape = first_reduce->operand(0)->shape(); + // The layout of a reduction input is either set by LayoutAssignment for + // unnested kReduce or by InstructionFusion for fused kReduce. 
+ CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << first_reduce->ToString(); + + bool is_row_reduction; + llvm_ir::KernelMappingScheme mapping_scheme; + std::tie(mapping_scheme, is_row_reduction) = + ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce); + ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); + KernelCodeGenerator kernel_generator( + /*tile_element_generator=*/ + [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc, + x_iter_num); + }, + /*block_prologue_generator=*/ + [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { + EmitPrologueForReduction(hlo, kernel_info); + }, + /*block_epilogue_generator*/ + [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { + EmitEpilogueForReduction(hlo, kernel_info); + }); + + LaunchDimensions launch_dimensions = + EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + + thunks.push_back(std::move(kernel_thunk)); + std::unique_ptr sequential_thunk = + absl::make_unique(std::move(thunks), unnested_hlo); + AddThunkToThunkSequence(std::move(sequential_thunk)); + + return Status::OK(); +} + Status IrEmitterUnnested::EmitConstantGlobals() { for (const BufferAllocation& allocation : ir_emitter_context_->buffer_assignment().Allocations()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index e09ed657a812be6ab4859a0e365a51c45a37bfed..f85e18bbf0798ef3d5b87e81d287d8aed691dfc4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -16,6 +16,7 @@ limitations under 
the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" @@ -68,11 +69,13 @@ class IrEmitterUnnested : public IrEmitter { explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) : mapping_scheme_(mapping_scheme), tiled_param_info_(nullptr), - lane_id_(nullptr) {} + lane_id_(nullptr), + index_ty_(nullptr) {} + virtual ~KernelCodegenInfo() {} void SetLaneId(llvm::Value* v) { lane_id_ = v; } + void SetIndexType(llvm::Type* t) { index_ty_ = t; } void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { - CHECK_EQ(tiled_param_info_, nullptr); tiled_param_info_ = tiled_param_info; } @@ -83,11 +86,13 @@ class IrEmitterUnnested : public IrEmitter { llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { return tiled_param_info_; } + llvm::Type* GetIndexType() const { return index_ty_; } - private: + protected: llvm_ir::KernelMappingScheme* mapping_scheme_; llvm_ir::TiledParameterInfo* tiled_param_info_; llvm::Value* lane_id_; + llvm::Type* index_ty_; }; // A function object to prepare for the code generation for a tile block. @@ -103,10 +108,12 @@ class IrEmitterUnnested : public IrEmitter { // y_loc: The y coordinate within a tile. // x_loc: The x coordinate within a tile. // kernel_info: Other information to support the kernel code generation. + // x_iter_num: When a thread process N elements in the X dimension, x_iter_num + // has a value of 0..N-1 to identify the element being process. using TileElementGenerator = std::function; + llvm::Value* x_loc, int64 x_iter_num)>; // KernelCodeGenerator records the code generator objects that generate code // for tile elements or tile block prologue/epilogue. 
@@ -169,8 +176,9 @@ class IrEmitterUnnested : public IrEmitter { Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; + Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleAfterAll(HloInstruction* after_all) override; Status EmitTargetElementLoop( @@ -200,82 +208,23 @@ class IrEmitterUnnested : public IrEmitter { // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( - const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, absl::Span> extra_output_gens); - // EmitColumnReduction and EmitRowReduction emit code for column and row - // reduction of a matrix and/or 3D tensor. Row and column reduction have - // different memory access pattern, so for performance their implementations - // are significantly different. + // Generates code for reduction to contiguous dimensions. // - // Emits code that reduces a matrix of shape [height x width] to a vector of - // [width]. Other parameters have the same meaning as those of - // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x width] with "height" - // being the major dimension. - Status EmitColumnReduction( - KernelThunk* kernel_thunk, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Emits code that reduces a 3D tensor of shape [depth x height x width] to a - // vector of shape [height]. 
Other parameters have the same meaning as those - // of `EmitReductionToVector`. Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x width] - // with "depth" being the most major dimension. - Status EmitRowReduction( - KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar( - KernelThunk* kernel_thunk, HloInstruction* reduce, - const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Figures out whether `reduce` is a row or column reduction, and which - // dimensions to reduce, and calls either `EmitRowReduction` or - // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the - // input array, which is the operand of the Reduce instruction if unfused or - // of the Fusion instruction if fused. `input_gen` and `init_value_gen` - // generate elements of the input and the initial value. Other parameters mean - // the same as for `HandleReduce`. - // - // Multiple reduces can be emitted in the same loop, assuming they have the - // same input and output shapes, and the same reduce dimensions. - // - // extra_output_gens can contain extra generators for intermediate outputs. - // These must have the same shape as the reduce input as they are computed - // when the reduce inputs are being read. 
- // - // Prerequisite: `IsReductionToVector(*reduce)` - Status EmitReductionToVector( - KernelThunk* kernel_thunk, HloInstruction* reduce, - const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span dimensions_to_reduce, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); + // Prerequisite: `IsReductionToVector(*unnested_hlo)` + Status EmitReductionToVector(HloInstruction* unnested_hlo); + + // Computes the KernelMappingScheme for the reduce HLO and indicates whether + // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo + // and first_reduce are the same instruction. For a kInput fusion, + // unnested_hlo is the fusion instruction while first_reduce is the first + // reduce op. + std::tuple + ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo, + const HloInstruction* first_reduce); // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from @@ -300,20 +249,45 @@ class IrEmitterUnnested : public IrEmitter { const KernelCodeGenerator& kernel_generator, KernelCodegenInfo* kernel_info); void EmitBlock(const TileGenerator& emit_one_tile, - const KernelCodegenInfo* kernel_info, - KernelSupportLibrary& ksl, llvm::Type* index_ty); + KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl, + llvm::Type* index_ty); // Emits code to process a tensor element in a tile for the given kCopy HLO // that performs a 0-2-1 transpose. void EmitTileElementForCopy(HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Emits code to process a tensor element in a tile for the given kLoop fusion // HLO containing parameters that are 0-2-1 transpose of its outputs. 
void EmitTileElementForFusion(HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); + // Emits code to process a tensor element in a tile for the given input hlo + // that is either a unnested kReduce or a kInput fusion. + void EmitTileElementForReduction(HloInstruction* unnested_hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); + // Prepares for the code generation for a tile block of a reduction kernel. + void EmitPrologueForReduction(HloInstruction* unnested_hlo, + KernelCodegenInfo* kernel_info); + void EmitPrologueForOneReduction(HloInstruction* unnested_hlo, + HloInstruction* reduce_inst, int reduce_idx, + KernelCodegenInfo* kernel_info, + GpuElementalIrEmitter* elemental_emitter, + ShapeIndex output_shape_index); + // Wraps up the code generation for a tile block of a reduction kernel. + void EmitEpilogueForReduction(HloInstruction* unnested_hlo, + KernelCodegenInfo* kernel_info); + // For each reducer, emits the shuffle-down loop to accumulate the partial + // result to the global result. + void EmitFullWarpShuffleDownLoopForAllReduces( + absl::Span reducers, + absl::Span partial_result_addresses); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. @@ -346,6 +320,9 @@ class IrEmitterUnnested : public IrEmitter { // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); + // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`. + std::unique_ptr BuildTriangularSolveThunk(const HloInstruction* inst); + // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs // to make sure `inst` outlives the lifetime of the returned Thunk object. 
std::unique_ptr BuildGemmThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index bd53b90b42d8e657a3ee58e7ca03fb60522aae28..153aab97d9eb971734c5ea95564895631bc2a9fa 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -110,11 +110,9 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, } // Gets the GPU name as it's known to LLVM for a given compute capability. If -// we see an unrecognized compute capability, we return "sm_30". +// we see an unrecognized compute capability, we return "sm_35". static string GetSmName(std::pair compute_capability) { static auto* m = new std::map, int>({ - {{3, 0}, 30}, - {{3, 2}, 32}, {{3, 5}, 35}, {{3, 7}, 37}, {{5, 0}, 50}, @@ -125,8 +123,9 @@ static string GetSmName(std::pair compute_capability) { {{6, 2}, 62}, {{7, 0}, 70}, {{7, 2}, 72}, + {{7, 5}, 75}, }); - int sm_version = 30; + int sm_version = 35; auto it = m->find(compute_capability); if (it != m->end()) { sm_version = it->second; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 01fddcede64d1bb02ab89db5fc9524893c2d47a4..02e1207f377b8c28bf2566bee8cf3bcbc66794fb 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -67,7 +67,7 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, } int64 profit = 0; for (auto instr : instr2->operands()) { - if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + if (!IsProfitableOperand(instr) || !in_list.contains(instr)) { continue; } profit += ShapeUtil::ByteSizeOf(instr->shape()); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc 
b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index d16c87ba5c63aa582753fe949e9e39ee2d8b81e5..40b87b16a195564c9b98497f79a70f1db0539d87 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -628,8 +628,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { p.1 = s32[1]{0} parameter(1) p.2 = f16[1,96,1024]{2,1,0} parameter(2) c.0 = s32[] constant(0) - pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } fusion.2 { @@ -638,7 +637,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { p.2 = f16[1,96,1024]{2,1,0} parameter(2) c.0 = s32[] constant(0) pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index f3e17d888242a36c268dcbfa0d6530f80cedceb0..6e00e4b4ff8c493f00fae3355215fb13fb5f4f10 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -36,6 +36,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" @@ -50,6 +53,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" @@ -77,6 +81,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -108,29 +114,58 @@ namespace { namespace tracing = tensorflow::tracing; -// Returns the directory containing nvvm libdevice files. config_cuda_data_dir -// should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the -// HloModule being compiled. 
-string GetLibdeviceDir(const string& config_cuda_data_dir) { - std::vector potential_libdevice_dirs; - if (!config_cuda_data_dir.empty()) { - potential_libdevice_dirs.push_back(config_cuda_data_dir); - } - potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); - - // Tries all potential libdevice directories in the order they are inserted. - // Returns the first directory that exists in the file system. - for (const string& potential_libdevice_dir : potential_libdevice_dirs) { - if (tensorflow::Env::Default()->IsDirectory(potential_libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << potential_libdevice_dir; - return potential_libdevice_dir; +// Returns a vector of potential locations of the CUDA root directory. +std::vector GetCudaRootCandidates( + const HloModuleConfig& hlo_module_config) { + std::vector potential_cuda_roots = tensorflow::CandidateCudaRoots(); + + // "." is our last resort, even though it probably won't work. + potential_cuda_roots.push_back("."); + + // CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has + // highest priority. + string xla_gpu_cuda_data_dir = + hlo_module_config.debug_options().xla_gpu_cuda_data_dir(); + if (!xla_gpu_cuda_data_dir.empty()) { + potential_cuda_roots.insert(potential_cuda_roots.begin(), + xla_gpu_cuda_data_dir); + } + return potential_cuda_roots; +} + +void PrintCantFindCudaMessage(absl::string_view msg, + const HloModuleConfig& hlo_module_config) { + LOG(WARNING) << msg; + LOG(WARNING) << "Searched in the following directories:"; + for (const auto& dir : GetCudaRootCandidates(hlo_module_config)) { + LOG(WARNING) << " " << dir; + } + LOG(WARNING) + << "You can choose the search directory by setting xla_gpu_cuda_data_dir " + "in HloModule's DebugOptions. For most apps, setting the environment " + "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."; +} + +// Returns the directory containing nvvm libdevice files. 
+string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { + const auto& candidate_dirs = GetCudaRootCandidates(hlo_module_config); + for (const string& cuda_root : candidate_dirs) { + string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; } - VLOG(2) << "Unable to find potential libdevice dir " - << potential_libdevice_dir; } + PrintCantFindCudaMessage( + "Can't find directory containing CUDA libevice. This may result in " + "compilation or runtime failures, if the program we try to run uses " + "routines from libdevice.", + hlo_module_config); - LOG(WARNING) << "Unable to find libdevice dir. Using '.'"; - // Last resort: maybe in the current folder. + // GetCudaRootCandidates always includes ".", but if everything fails, we + // return it anyway. Better than returning the empty string. return "."; } @@ -145,6 +180,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -152,6 +188,16 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); + auto cost_model = [](HloInstruction* conv) { + // We need a cost model for GPUs. Currently, do nothing. + return false; + }; + pipeline.AddPass(false); + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/true); + // Expand the sort op to support stable sorting if required. 
+ pipeline.AddPass(); // Convert BF16 operations to F32 operations so that the GPU backend can // support BF16 operations without directly implementing a BF16 lowering for // most ops. @@ -180,10 +226,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return false; }); - options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifierOptions options; pass.AddPass(options); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -252,12 +297,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - AlgebraicSimplifierOptions options( - /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { - return true; - }); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); - options.set_enable_permutation_sort_replacement(true); pipeline.AddPass>(options); // Choose the fastest algorithm for each conv. @@ -361,6 +402,7 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } @@ -478,14 +520,19 @@ void WarnIfBadDriverJITVersion() { // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. 
-StatusOr> CompilePtx(const string& ptx, int cc_major, - int cc_minor, - bool disable_ptx_optimizations) { +StatusOr> CompilePtx( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config) { tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); - const string ptxas_path = - tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); - VLOG(2) << "Checking ptxas at " << ptxas_path; auto env = tensorflow::Env::Default(); + string ptxas_path; + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); + VLOG(2) << "Looking for ptxas at " << ptxas_path; + if (env->FileExists(ptxas_path).ok()) { + break; + } + } TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -520,7 +567,7 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } - if (disable_ptx_optimizations) { + if (hlo_module_config.debug_options().xla_gpu_disable_ptxas_optimizations()) { ptxas_args.push_back("-O0"); } ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); @@ -685,12 +732,8 @@ StatusOr> NVPTXCompiler::RunBackend( // Find the directory containing libdevice. To avoid searching for it every // time, we have a one-element cache, keyed on the module's config's // cuda_data_dir. 
- const auto& config_cuda_data_dir = - module->config().debug_options().xla_gpu_cuda_data_dir(); - if (cached_libdevice_dir_.empty() || - cached_cuda_data_dir_ != config_cuda_data_dir) { - cached_cuda_data_dir_ = config_cuda_data_dir; - cached_libdevice_dir_ = GetLibdeviceDir(config_cuda_data_dir); + if (cached_libdevice_dir_.empty()) { + cached_libdevice_dir_ = GetLibdeviceDir(module->config()); } libdevice_dir = cached_libdevice_dir_; } @@ -743,9 +786,8 @@ StatusOr> NVPTXCompiler::RunBackend( } } - const std::vector cubin = CompilePtxOrGetCachedResult( - ptx, cc_major, cc_minor, - module->config().debug_options().xla_gpu_disable_ptxas_optimizations()); + const std::vector cubin = + CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor, module->config()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), @@ -756,14 +798,19 @@ StatusOr> NVPTXCompiler::RunBackend( std::unique_ptr profile_index_map; std::unique_ptr profile_printer; - if (module->config().hlo_profiling_enabled()) { + if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); cost_analysis.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - profile_index_map = absl::make_unique(*module); - profile_printer = CreateHloProfilePrinterData( - *profile_index_map, cost_analysis, entry_computation->name()); + VLOG(1) << "HLO memory read+written: " + << tensorflow::strings::HumanReadableNumBytes( + cost_analysis.bytes_accessed()); + if (module->config().hlo_profiling_enabled()) { + profile_index_map = absl::make_unique(*module); + profile_printer = CreateHloProfilePrinterData( + *profile_index_map, cost_analysis, entry_computation->name()); + } } auto* gpu_executable = new GpuExecutable( @@ -779,7 +826,7 @@ StatusOr> NVPTXCompiler::RunBackend( std::vector 
NVPTXCompiler::CompilePtxOrGetCachedResult( const string& ptx, int cc_major, int cc_minor, - bool disable_ptx_optimizations) { + const HloModuleConfig& hlo_module_config) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; @@ -807,8 +854,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = CompilePtx( - *cache_ptx, cc_major, cc_minor, disable_ptx_optimizations); + StatusOr> maybe_cubin = + CompilePtx(*cache_ptx, cc_major, cc_minor, hlo_module_config); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() @@ -827,10 +874,11 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( log_warning = !warning_done.exchange(true); } if (log_warning) { - LOG(WARNING) - << "Failed to compile ptx to cubin. Will attempt to let " - "GPU driver compile the ptx. " - << maybe_cubin.status(); + PrintCantFindCudaMessage( + "Can't find ptxas binary. Will back to the GPU driver " + "for PTX -> sass compilation. This is OK so long as you don't " + "see a warning below about an out-of-date driver version.", + hlo_module_config); } // We're going to use the driver to JIT our PTX->SASS, so warn if diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index be5e31a50112686841e6f18b76f382a56e61bafc..b2077f42fd097330703fde063d80a20704fa48e2 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -99,7 +99,7 @@ class NVPTXCompiler : public LLVMCompiler { // compiled cubin. If compilation was unsuccessful, returns an empty vector. 
std::vector CompilePtxOrGetCachedResult( const string& ptx, int cc_major, int cc_minor, - bool disable_ptx_optimizations); + const HloModuleConfig& hlo_module_config); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 8154d75d23a6d49153ccb6824402aff73f365617..cb012649200c6386d3ae25d088aa3b16bd40be82 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 375f68a15957936151aee068582a714b62694af2..bfed4f5230dfe37bca48560ce83a2dd82c8950a4 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -39,6 +39,25 @@ std::ostream& operator<<(std::ostream& out, return out; } +int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc) { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. 
" + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; + } + } + return threads_per_block; +} + // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, @@ -62,21 +81,7 @@ LaunchDimensions CalculateLaunchDimensions( // // * = - int64 threads_per_block = device_desc.threads_per_block_limit(); - if (threads_per_block == 0) { - static std::atomic log_count{0}; - if (log_count.fetch_add(1) < 8) { - LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " - "without full information about its capabilities. " - "StreamExecutor's PopulateDeviceDescription should be " - "updated for this device."; - } - threads_per_block = device_desc.threads_per_warp(); - if (threads_per_block == 0) { - // Fall back to *something* if we can't even get num threads per warp. - threads_per_block = 32; - } - } + int64 threads_per_block = ThreadsPerBlockLimit(device_desc); if (num_elements < threads_per_block) { threads_per_block = num_elements; diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 02471129e004b4876ce20a62cade34060c65b478..eb41dcccb938ccc088c2371def96ca73276771ab 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -57,6 +57,9 @@ class LaunchDimensions { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims); +// Returns the maximum number of threads per block allowed by the device. 
+int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc); + LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, int unroll_factor = 1); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 4775baf44aecfe6adaf2bf0d2791595436635b16..1dedbd3befce6e2ceb06126d83a061207a90dd8f 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -25,7 +26,7 @@ namespace xla { namespace gpu { bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const { - return hlo_to_stream_number_.count(&hlo); + return hlo_to_stream_number_.contains(&hlo); } int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { @@ -98,10 +99,10 @@ int ComputeStreamToAssign( // greedy approach. First, we compute as forbidden_stream_numbers the // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign // `hlo` a different stream. 
- std::set forbidden_stream_numbers; + absl::flat_hash_set forbidden_stream_numbers; for (const auto* seen_gemm : seen_gemms) { int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm); - if (!forbidden_stream_numbers.count(stream_num) && + if (!forbidden_stream_numbers.contains(stream_num) && CanRunConcurrently(*seen_gemm, hlo, reachability)) { forbidden_stream_numbers.insert(stream_num); } @@ -109,7 +110,7 @@ int ComputeStreamToAssign( for (int stream_num = 0; stream_num < stream_assignment.StreamCount(); ++stream_num) { - if (!forbidden_stream_numbers.count(stream_num)) { + if (!forbidden_stream_numbers.contains(stream_num)) { return stream_num; } } diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 1fc46bafa10e7ba6c896f081d5c836bd400886c9..92e4d6dbbc1bd564657f8a5de09d23d5ae81a93e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index a1ed8499040359fe7265a7317b0577a990a2234c..d33e9cf714ee3810b1fb2fa8c05c3ed399d27bfb 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index a302b582ede3723acd118d2e4a4bb3efdf7a4d0b..869724db601b2d5e4ed6d3c7bf3e10a748433146 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -65,7 +65,7 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -91,7 +91,7 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -118,7 +118,7 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -152,7 +152,7 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -187,13 +187,13 @@ TEST_F(GpuKernelTilingTest, 
CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); } -TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { +TEST_F(GpuKernelTilingTest, TransposedInputWithUserReverseNotTiled) { const char *const kHloString = R"( HloModule FusionTransposeWithReverseNotTiled fused_computation.1 { @@ -214,12 +214,203 @@ TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, TransposedInputWithUserBitcastNotTiled) { + const char *const kHloString = R"( + HloModule TransposedInputWithUserBitcast + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + ROOT bitcast = f32[20,20]{0,1} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = f32[20,20]{0,1} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, TransposedInputWithoutUnsafeUseTiled) { + const char *const kHloString = R"( + HloModule TwoTransposedInputs + + fused_computation { + param_0 = f32[64,64]{1,0} parameter(0) + param_1 = f32[64,64]{1,0} parameter(1) + bitcast = f32[64,64]{0,1} bitcast(param_0) + copy = f32[64,64]{0,1} copy(param_1) + ROOT tuple = (f32[64,64]{0,1}, f32[64,64]{0,1}) tuple(bitcast, copy) + } + + ENTRY kernel_entry { + parameter.0 = f32[64,64]{1,0} parameter(0) + parameter.1 = f32[64,64]{1,0} parameter(1) + ROOT fusion = (f32[64,64]{0,1}, f32[64,64]{0,1}) + fusion(parameter.0, parameter.1), + kind=kLoop, calls=fused_computation + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, ColumnReductionWithPowerOf2OutputElementsUnrolled) { + const char *const kHloString = R"( + HloModule column_reduce_powerof2 + + reduction { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY kernel_entry { + constant0 = f32[] constant(0) + arg1 = f16[1024,512]{1,0} parameter(0) + arg1_conv = f32[1024,512]{1,0} convert(arg1) + ROOT reduce = f32[512]{0} reduce(arg1_conv, constant0), dimensions={0}, to_apply=reduction + })"; + + // Check that two calls to llvm.nvvm.atomic are generated. 
+ auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + +TEST_F(GpuKernelTilingTest, + ColumnReductionWithInputLargerThenReduceInputNotUnrolled) { + const char *const kHloString = R"( + HloModule larger_than_reduce_input_parameter + + reduction22 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + constant0 = f32[] constant(0) + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1027,513]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1027,513]{1,0} convert(arg.2) + slice2 = f32[1024,512]{1,0} slice(arg2.conv), slice={[2:1026], [1:513]} + add2 = f32[1024,512]{1,0} add(arg1.conv, slice2) + ROOT reduce = f32[512]{0} reduce(add2, constant0), dimensions={0}, + to_apply=reduction22 + } + + ENTRY kernel_entry { + arg1 = f16[1024,512]{1,0} parameter(0) + arg2 = f16[1027,513]{1,0} parameter(1) + ROOT fusion = f32[512]{0} fusion(arg1, arg2), kind=kInput, + calls=fused_computation + })"; + + // Check that one call to llvm.nvvm.atomic is generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + +TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) { + const char *const kHloString = R"( + HloModule column_reduce_powerof2_mof + + reduction22 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + constant0 = f32[] constant(0) + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1024,512]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1024,512]{1,0} convert(arg.2) + reduce1 = f32[512]{0} reduce(arg1.conv, constant0), dimensions={0}, + to_apply=reduction22 + reduce2 = f32[512]{0} reduce(arg2.conv, constant0), dimensions={0}, + to_apply=reduction22 + add = f32[1024,512]{1,0} add(arg1.conv, arg2.conv) + ROOT tuple = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0}) + tuple(reduce1, reduce2, add) + } + + ENTRY kernel_entry { + arg1 = f16[1024,512]{1,0} parameter(0) + arg2 = f16[1024,512]{1,0} parameter(1) + ROOT fusion = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0}) + fusion(arg1, arg2), kind=kInput, calls=fused_computation + })"; + + // Check that four calls to llvm.nvvm.atomic are generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index f8120a5fa00ce38644cd85c54d5ef65701be1eda..06b06a5b1ee1fb9996be3ebe326893c4160a7e29 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" @@ -43,7 +42,7 @@ class InfeedTest : public ClientLibraryTestBase { ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); Infeed(&builder, literal.shape()); - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { // TODO(b/30609564): Use ComputeAndCompareLiteral instead. 
ComputeAndCompareTuple(&builder, literal, {}); } else { diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc index c78605cebbc671272b8df9faf0e0cc54be2f5b1c..a677617727c04811584cbaa295d164ed27273bb2 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk.cc @@ -48,6 +48,8 @@ std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { return os << "kOutfeed"; case Thunk::kSequential: return os << "kSequential"; + case Thunk::kTriangularSolve: + return os << "kTriangularSolve"; case Thunk::kTuple: return os << "kTuple"; case Thunk::kWhile: diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index e68bee035a029178844282995429eaa960cc4817..bc69af897a01775d2d33d46067464b10e049f3e1 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -56,6 +56,7 @@ class Thunk { kMemzero, kOutfeed, kSequential, + kTriangularSolve, kTuple, kWhile, }; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 6b2d76764a077dc6cfa3f9ddc6e525ab330323be..25bad67bab9375559c431466571c62acd0452b01 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -14,17 +14,19 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/map_util.h" namespace xla { namespace gpu { void ThunkSchedule::AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, - const std::unordered_map& hlo_to_thunk) { - if (hlo_to_thunk.count(&operand)) { + const absl::flat_hash_map& hlo_to_thunk) { + if (hlo_to_thunk.contains(&operand)) { // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency // list if `operand` is assigned to a different stream. As an optimization, // we skip `operand`'s operands because `operand` depends on them already. @@ -48,14 +50,14 @@ ThunkSchedule::ThunkSchedule( const std::vector& hlo_total_order) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { - std::unordered_map hlo_to_thunk; + absl::flat_hash_map hlo_to_thunk; for (const auto& thunk : *thunks_) { InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } for (HloInstruction* hlo : hlo_total_order) { - if (hlo_to_thunk.count(hlo)) { - thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); + if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) { + thunk_total_order_.push_back(*thunk); } } @@ -106,7 +108,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { // redundant dependency edge. 
Array2D last_dependency(stream_count, stream_count, -1); for (const Thunk* dst : thunk_total_order_) { - if (!depends_on_.count(dst)) { + if (!depends_on_.contains(dst)) { continue; } @@ -134,7 +136,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { const std::list& ThunkSchedule::DependsOn( const Thunk* thunk) const { - if (depends_on_.count(thunk)) { + if (depends_on_.contains(thunk)) { return FindOrDie(depends_on_, thunk); } else { return empty_thunk_list_; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index 43b628a1baf0e79a3197f3cfad3547991642eaed..549378debd52417252724a5d8a6f4d24f2ad0369 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -54,7 +56,9 @@ class ThunkSchedule { // Thunks that `thunk` depends on. const std::list& DependsOn(const Thunk* thunk) const; // Whether `thunk` is depended by another thunk. - bool Depended(const Thunk* thunk) const { return depended_by_.count(thunk); } + bool Depended(const Thunk* thunk) const { + return depended_by_.contains(thunk); + } // Delegates to StreamAssignment. int StreamCount() const { return stream_assignment_->StreamCount(); } @@ -75,13 +79,13 @@ class ThunkSchedule { // thunk.hlo_instruction(). 
void AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, - const std::unordered_map& hlo_to_thunk); + const absl::flat_hash_map& hlo_to_thunk); std::unique_ptr thunks_; std::vector thunk_total_order_; - std::unordered_map> depends_on_; - std::set depended_by_; + absl::flat_hash_map> depends_on_; + absl::flat_hash_set depended_by_; std::list empty_thunk_list_; std::unique_ptr stream_assignment_; diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..5200a2af412979c7e38d95c5a9bd5bc2ab64f086 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc @@ -0,0 +1,149 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" +#include "tensorflow/stream_executor/device_memory.h" + +namespace xla { +namespace gpu { + +TriangularSolveThunk::TriangularSolveThunk( + const TriangularSolveOptions& options, + const BufferAllocation::Slice& a_buffer, + const BufferAllocation::Slice& b_buffer, PrimitiveType type, + int64 batch_size, int64 m, int64 n, int64 a_batch_stride, + int64 b_batch_stride, const HloInstruction* hlo) + : Thunk(Kind::kTriangularSolve, hlo), + uplo_(options.lower() ? se::blas::UpperLower::kLower + : se::blas::UpperLower::kUpper), + side_(options.left_side() ? se::blas::Side::kLeft + : se::blas::Side::kRight), + unit_diagonal_(options.unit_diagonal() ? 
se::blas::Diagonal::kUnit + : se::blas::Diagonal::kNonUnit), + a_buffer_(a_buffer), + b_buffer_(b_buffer), + type_(type), + batch_size_(batch_size), + m_(m), + n_(n), + a_batch_stride_(a_batch_stride), + b_batch_stride_(b_batch_stride) { + transpose_a_ = [&] { + switch (options.transpose_a()) { + case TriangularSolveOptions::NO_TRANSPOSE: + return se::blas::Transpose::kNoTranspose; + case TriangularSolveOptions::TRANSPOSE: + return se::blas::Transpose::kTranspose; + case TriangularSolveOptions::ADJOINT: + return se::blas::Transpose::kConjugateTranspose; + default: + LOG(ERROR) << "Invalid triangular solve transpose value " + << options.transpose_a(); + return se::blas::Transpose::kNoTranspose; + } + }(); +} + +Status TriangularSolveThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_) + << " side=" << se::blas::SideString(side_) + << " diagonal=" << se::blas::DiagonalString(unit_diagonal_) + << " batch_size=" << batch_size_ << " m=" << m_ << " n=" << n_ + << " a_batch_stride=" << a_batch_stride_ + << " b_batch_stride=" << b_batch_stride_; + + const int lda = side_ == se::blas::Side::kLeft ? 
m_ : n_; + const int ldb = m_; + + char* a_base = static_cast( + buffer_allocations.GetDeviceAddress(a_buffer_).opaque()); + char* b_base = static_cast( + buffer_allocations.GetDeviceAddress(b_buffer_).opaque()); + for (int64 i = 0; i < batch_size_; ++i) { + bool launch_ok; + se::DeviceMemoryBase a_data = + se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_); + se::DeviceMemoryBase b_data = + se::DeviceMemoryBase(b_base + i * b_batch_stride_, b_batch_stride_); + switch (type_) { + case F32: { + se::DeviceMemory b_data_typed(b_data); + launch_ok = stream + ->ThenBlasTrsm(side_, uplo_, transpose_a_, + unit_diagonal_, m_, n_, /*alpha=*/1.0f, + se::DeviceMemory(a_data), lda, + &b_data_typed, ldb) + .ok(); + break; + } + case F64: { + se::DeviceMemory b_data_typed(b_data); + launch_ok = stream + ->ThenBlasTrsm(side_, uplo_, transpose_a_, + unit_diagonal_, m_, n_, /*alpha=*/1.0, + se::DeviceMemory(a_data), lda, + &b_data_typed, ldb) + .ok(); + break; + } + case C64: { + se::DeviceMemory> b_data_typed(b_data); + launch_ok = + stream + ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_, + n_, /*alpha=*/1.0f, + se::DeviceMemory>(a_data), + lda, &b_data_typed, ldb) + .ok(); + break; + } + case C128: { + se::DeviceMemory> b_data_typed(b_data); + launch_ok = + stream + ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_, + n_, /*alpha=*/1.0, + se::DeviceMemory>(a_data), + lda, &b_data_typed, ldb) + .ok(); + break; + } + default: + return InvalidArgument("Invalid type for triangular solve %d", type_); + } + if (!launch_ok) { + return InternalError("Unable to launch triangular solve for thunk %p", + this); + } + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..c947162ea32f197f808d099859eadbbc55a65ab1 --- /dev/null +++ 
b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" + +namespace xla { +namespace gpu { + +// This class stores everything that StreamExecutor needs to launch a triangular +// solve (BlasTrsm). It is generated by IrEmitter. +// +// Thread-compatible. 
+class TriangularSolveThunk : public Thunk { + public: + TriangularSolveThunk(const TriangularSolveOptions& options, + const BufferAllocation::Slice& a_buffer, + const BufferAllocation::Slice& b_buffer, + PrimitiveType type, int64 batch_size, int64 m, int64 n, + int64 a_batch_stride, int64 b_batch_stride, + const HloInstruction* hlo); + + TriangularSolveThunk(const TriangularSolveThunk&) = delete; + TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream, + HloExecutionProfiler* profiler) override; + + private: + const se::blas::UpperLower uplo_; + const se::blas::Side side_; + const se::blas::Diagonal unit_diagonal_; + se::blas::Transpose transpose_a_; + + const BufferAllocation::Slice a_buffer_; + const BufferAllocation::Slice b_buffer_; + + const PrimitiveType type_; + const int64 batch_size_; + const int64 m_; + const int64 n_; + const int64 a_batch_stride_; + const int64 b_batch_stride_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc index c552c2925497f1c4808d74a615d35cdbeeba1858..bbbcc2dbb0f71d08462a1aad6d97e7fd07b2a1fb 100644 --- a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h index dd46ff433ba0ad6bfa3999b96845fdaebe148aca..167c038420a64d9fa29746ed3fe349620e08e6ff 100644 --- a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h +++ b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h @@ -47,6 +47,10 @@ class XfeedQueue { // Blocks until the queue is non-empty, then returns the buffer at the head of // the queue. BufferType BlockingGetNextDestination() { + for (const auto& callback : before_get_next_dest_callbacks_) { + callback(); + } + bool became_empty; BufferType current_buffer; { @@ -69,6 +73,10 @@ class XfeedQueue { void RegisterOnEmptyCallback(std::function callback) { on_empty_callbacks_.push_back(std::move(callback)); } + void RegisterBeforeGetNextDestinationCallback( + std::function callback) { + before_get_next_dest_callbacks_.push_back(std::move(callback)); + } private: tensorflow::mutex mu_; @@ -82,6 +90,11 @@ class XfeedQueue { // List of callbacks which will be called when 'enqueued_buffers_' becomes // empty. std::vector> on_empty_callbacks_; + + // List of callbacks which will be called before BlockingGetNextDestination() + // is called. This lets you e.g. call EnqueueDestination() for each call to + // BlockingGetNextDestination(). 
+ std::vector> before_get_next_dest_callbacks_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 9220865867b770eebfb1ada8f31a5d24693a4b8d..4fca981c6a59cdb91a997e6a887fd26472c1a10a 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -199,7 +199,7 @@ Status HeapSimulator::RunComputation( // If the buffer has no users and isn't an entry parameter or output, it // must be a dead value. - if (live_buffers.count(buffer) == 0) { + if (!live_buffers.contains(buffer)) { dead_buffers_to_free.push_back(buffer); } } @@ -225,10 +225,10 @@ Status HeapSimulator::RunComputation( } } // Sort to get a deterministic iteration order. - std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), - [](const BufferValue* x, const BufferValue* y) { - return x->id() < y->id(); - }); + absl::c_sort(operand_buffers_to_free, + [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); // Allocate buffers defined by this instruction. This is the latest point // that we can allocate; right before the buffer is first used. 
This must @@ -253,7 +253,7 @@ Status HeapSimulator::RunComputation( bool shared = false; if (options_.may_reuse_operand_buffers) { for (const BufferValue* operand_buffer : operand_buffers_to_free) { - if (reused_buffers.count(operand_buffer) != 0) { + if (reused_buffers.contains(operand_buffer)) { continue; } if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && @@ -335,10 +335,9 @@ Status HeapSimulator::RunComputation( to_free.push_back(buffer); } - std::sort(to_free.begin(), to_free.end(), - [](const BufferValue* x, const BufferValue* y) { - return x->id() < y->id(); - }); + absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); for (const BufferValue* buffer : to_free) { VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); @@ -374,15 +373,15 @@ bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { return true; } return options_.buffers_to_assign != nullptr && - options_.buffers_to_assign->count(buffer) == 0; + !options_.buffers_to_assign->contains(buffer); } // Alloc always calls the underlying heap algorithm. 
void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { - CHECK(allocated_buffers_.count(buffer) == 0) + CHECK(!allocated_buffers_.contains(buffer)) << "Alloc called on allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "Alloc called on freed buffer: " << *buffer; allocated_buffers_.insert(buffer); @@ -411,9 +410,9 @@ void HeapSimulator::Free(const BufferValue* buffer, buffer = group->canonical; } - CHECK(allocated_buffers_.count(buffer) > 0) + CHECK(allocated_buffers_.contains(buffer)) << "Free called on non-allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "Free called on freed buffer: " << *buffer; freed_buffers_.insert(buffer); @@ -433,11 +432,11 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; - CHECK(allocated_buffers_.count(buffer) == 0) + CHECK(!allocated_buffers_.contains(buffer)) << "ShareBuffer called on allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "ShareBuffer called on freed buffer: " << *buffer; - CHECK(freed_buffers_.count(shared) == 0) + CHECK(!freed_buffers_.contains(shared)) << "ShareBuffer called on freed shared buffer: " << *shared; const BufferValue* canonical = nullptr; @@ -452,7 +451,7 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, } else { // The 'shared' buffer doesn't have a group; it must be the canonical. Add // both 'buffer' and 'shared' to a new group. 
- CHECK(allocated_buffers_.count(shared) > 0) + CHECK(allocated_buffers_.contains(shared)) << "ShareBuffer called on non-allocated shared buffer: " << *shared; auto group = std::make_shared(); canonical = shared; @@ -596,7 +595,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { } // Call ops in the run sorted by decreasing size, breaking ties by buffer id. - std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) { + absl::c_sort(run_, [](const Op& a, const Op& b) { if (a.size != b.size) { return a.size > b.size; } @@ -866,23 +865,23 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { for (auto& entry : buffer_intervals_) { sorted_buffer_intervals.push_back(entry.second); } - std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), - [](const BufferInterval& x, const BufferInterval& y) { - if (x.size != y.size) { - return x.size > y.size; - } - if (x.end - x.start != y.end - y.start) { - return x.end - x.start > y.end - y.start; - } - return x.buffer->id() < y.buffer->id(); - }); + absl::c_sort(sorted_buffer_intervals, + [](const BufferInterval& x, const BufferInterval& y) { + if (x.size != y.size) { + return x.size > y.size; + } + if (x.end - x.start != y.end - y.start) { + return x.end - x.start > y.end - y.start; + } + return x.buffer->id() < y.buffer->id(); + }); BufferIntervalTree interval_tree(sorted_buffer_intervals.size()); for (auto& buffer_interval : sorted_buffer_intervals) { auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( buffer_interval.start, buffer_interval.end); - std::sort( - chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), + absl::c_sort( + chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); // Find the minimum free chunk that can hold this buffer. 
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index dbbf43082f2c1d21f5ef42f53804bf0969903a58..3e0631aeb4aa374cb5748650e1c7529e26e10b34 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -158,7 +158,7 @@ class HeapSimulator { void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, const BufferValue* buffer, const HloInstruction* instruction, - const BufferValue* shared_with_canonical); + const BufferValue* share_with_canonical); // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, // in which case we are calculating the same allocs/frees twice in the diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 414c63271245315f037d04924c9291a9cd5b7a77..ae9e3169fd9b7a4655ab91ffb1589b845402ba8d 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 58 +// Next ID: 62 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -82,6 +82,8 @@ message HloInstructionProto { // it will use a default value of 1. int64 feature_group_count = 50; + int64 batch_group_count = 58; + // Describes the [begin, end) index range and stride for slices. message SliceDimensions { int64 start = 1; @@ -166,13 +168,16 @@ message HloInstructionProto { // Cross replica op fields. repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; - string cross_replica_sum_barrier = 46; + string all_reduce_barrier = 46; // Whether this Send/Recv instruction transfers data to/from the host. Only // present for Send and Recv instructions and their SendDone and RecvDone // partners. bool is_host_transfer = 47; + // Whether this Sort instruction should be stable. 
+ bool is_stable = 60; + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. @@ -191,6 +196,12 @@ message HloInstructionProto { // operand. bool constrain_layout = 56; repeated xla.ShapeProto operand_shapes_with_layout = 57; + + // Options for TriangularSolve + xla.TriangularSolveOptions triangular_solve_options = 59; + + // Describes how parameters behave with regards to replicas. + xla.ParameterReplication parameter_replication = 61; } // Serialization of HloComputation. @@ -227,6 +238,18 @@ message HloScheduleProto { } message HloInputOutputAliasProto { + enum Kind { + // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 + // behavior and missing has_*() APIs. + UNDEFINED_ALIAS = 0; + // An alias setup by the user as must alias. A use setting USER_ALIAS is + // expecting the designed output to be dropped over the given input + // parameter number+index. + USER_ALIAS = 1; + // An alias setup by the compiler as part of its optimizations. + SYSTEM_ALIAS = 2; + } + // The following proto describes a pair of aliased an input // (described by parameter number and a ShapeIndex of the parameter) // and an output (described by a ShapeIndex of the root @@ -247,6 +270,8 @@ message HloInputOutputAliasProto { int64 parameter_number = 2; // ShapeIndex of the parameter instruction. repeated int64 parameter_shape_index = 3; + // The kind of alias to be setup. 
+ Kind kind = 4; } repeated AliasEntryProto entries = 1; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index cf8e6594cbe5ffd28ca75dd5006e8817f1e8581c..e511f1951c5dd07ebb64fa38fd5b7f6a0e87b429 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -117,7 +117,7 @@ class BufferValueMap { for (const auto& pair : buffers_) { buffer_numbers.push_back(pair.first); } - std::sort(buffer_numbers.begin(), buffer_numbers.end()); + absl::c_sort(buffer_numbers); return buffer_numbers; } @@ -176,13 +176,12 @@ class BufferValueMap { const HloValue& value, std::vector* aliased_buffers) { // Get parameter value from an aliased_input object. const auto get_parameter_value = - [this](const std::pair& aliased_input) + [this](const HloInputOutputAliasConfig::Alias& aliased_input) -> const HloValue& { - int64 param_number = aliased_input.first; - const ShapeIndex& param_index = aliased_input.second; return dataflow_.GetUniqueValueAt( - module_->entry_computation()->parameter_instruction(param_number), - param_index); + module_->entry_computation()->parameter_instruction( + aliased_input.parameter_number), + aliased_input.parameter_index); }; // If the value shows up in a root instruction, alias it with parameter @@ -319,7 +318,7 @@ class BufferValueMap { ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. - std::sort(aliased_buffers.begin(), aliased_buffers.end()); + absl::c_sort(aliased_buffers); aliased_buffers.erase( std::unique(aliased_buffers.begin(), aliased_buffers.end()), aliased_buffers.end()); @@ -367,7 +366,7 @@ std::vector HloAliasAnalysis::ComputeBuffersAt( } // Sort and uniquify vector before returning. 
- std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan); + absl::c_sort(buffers, HloBuffer::IdLessThan); buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end()); return buffers; @@ -430,8 +429,7 @@ Status HloAliasAnalysis::Verify() const { for (const auto& pair : value_to_buffer_) { const HloValue* value = pair.first; const HloBuffer& buffer = *pair.second; - TF_RET_CHECK(std::find(buffer.values().begin(), buffer.values().end(), - value) != buffer.values().end()); + TF_RET_CHECK(absl::c_linear_search(buffer.values(), value)); } for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) { @@ -457,7 +455,7 @@ string HloAliasAnalysis::ToString() const { for (const HloComputation* computation : module_->computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { ShapeUtil::ForEachSubshape( instruction->shape(), [&out, &instruction, this](const Shape&, const ShapeIndex& index) { @@ -515,7 +513,7 @@ StatusOr> HloAliasAnalysis::Run( auto& value_set = buffer_map.GetValuesInBuffer(buffer_number); std::vector sorted_values(value_set.begin(), value_set.end()); - std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan); + absl::c_sort(sorted_values, HloValue::IdLessThan); alias_analysis->buffers_.emplace_back(next_id++, sorted_values); for (const HloValue* value : sorted_values) { alias_analysis->value_to_buffer_[value] = @@ -533,11 +531,11 @@ bool HloAliasAnalysis::HasLiveRangeInterference( const HloOrdering& ordering) const { for (const HloBuffer& buffer : buffers()) { CHECK(!buffer.values().empty()); - if (ShapeUtil::IsToken(buffer.values().front()->shape())) { + if (buffer.values().front()->shape().IsToken()) { // Tokens have no on-device representation and cannot interfere. 
for (const HloValue* value : buffer.values()) { // If one of the values is a token, all values must be a token. - DCHECK(ShapeUtil::IsToken(value->shape())); + DCHECK(value->shape().IsToken()); } continue; } @@ -547,16 +545,15 @@ bool HloAliasAnalysis::HasLiveRangeInterference( // tie-break using value ID. The tie-break is necessary because we need a // strict weak order for std::sort. std::vector values = buffer.values(); - std::sort(values.begin(), values.end(), - [&ordering](const HloValue* a, const HloValue* b) { - if (ordering.IsDefinedBefore(*a, *b)) { - return true; - } else if (ordering.IsDefinedBefore(*b, *a)) { - return false; - } else { - return a->id() < b->id(); - } - }); + absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) { + if (ordering.IsDefinedBefore(*a, *b)) { + return true; + } else if (ordering.IsDefinedBefore(*b, *a)) { + return false; + } else { + return a->id() < b->id(); + } + }); // Walk through the ordered vector of values. First verify that the values // are totally ordered with respect to 'ordering', then check that no diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 7e6150e94153cd15463725e862ce1b8593f2c991..b6dbf07959c541bceaa8eda5a0101503970ee832 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -238,13 +238,16 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, 
/*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); // Cannot alias an output twice. ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -279,13 +282,16 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); // Cannot alias an output twice. 
ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -365,9 +371,11 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 9c3aa0e64d119c2560f4955d0bcb492519fa52a2..32e48651b30bace4723169935d1f10dd7d7bfec3 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -49,7 +49,7 @@ std::vector HloBuffer::ComputePositions() const { value->positions().end()); } // Remove duplicates and sort positions. 
- std::sort(positions.begin(), positions.end()); + absl::c_sort(positions); positions.erase(std::unique(positions.begin(), positions.end()), positions.end()); return positions; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ff122b529bdcdcc69d2245136e19101902dbf957..817e15f9ff10a9b7e1a502265c85f70fdd681dd9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include #include +#include #include #include #include @@ -207,14 +207,14 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( TF_RET_CHECK(instruction->user_count() == 0); TF_RET_CHECK(IsRemovable(instruction)) << "Cannot remove instruction: " << instruction->ToString(); - std::unordered_set removed; + absl::flat_hash_set removed; std::queue worklist; worklist.push(instruction); while (!worklist.empty()) { HloInstruction* item = worklist.front(); worklist.pop(); - if (removed.count(item) != 0 || item->user_count() != 0 || + if (removed.contains(item) || item->user_count() != 0 || item == root_instruction() || !IsRemovable(item) || (item->HasSideEffect() && item != instruction)) { continue; @@ -296,7 +296,7 @@ void ComputeComputationPostOrder(HloComputation* computation, } // namespace void HloComputation::ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyMap& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_group, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const { std::vector dfs_stack; @@ -320,66 +320,75 @@ void HloComputation::ComputeInstructionPostOrder( visited->insert({current, kVisiting}); - // Add the operands to the stack in reverse order so the first operand is - // processed first. 
This will produce a more natural ordering and a nicer - // result for things like HLO stringification. - const auto& operands = current->operands(); - for (int64 i = operands.size() - 1; i >= 0; --i) { - dfs_stack.emplace_back(operands[i]); - } - - for (HloInstruction* op : current->control_predecessors()) { - dfs_stack.emplace_back(op); - } - - // Add inputs for send->recv_done dependencies and cross-replica-sum - // dependencies. - switch (current->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(current->channel_id()); - if (it != channel_dependency_map.end()) { - for (HloInstruction* op : it->second) { - dfs_stack.emplace_back(op); - } - } - break; + const auto get_channel_id = + [](HloInstruction* inst) -> absl::optional { + switch (inst->opcode()) { + case HloOpcode::kRecvDone: + return inst->channel_id(); + case HloOpcode::kAllReduce: + return inst->all_reduce_id(); + default: + return absl::nullopt; } - case HloOpcode::kCrossReplicaSum: { - auto all_reduce_id = current->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - for (HloInstruction* op : it->second) { - dfs_stack.emplace_back(op); - } - } + }; + + // When adding a predecessor to the dfs_stack, we need to also add its + // associated channel dependencies. + const auto add_dfs_stack = [&](HloInstruction* inst) { + auto channel_id = get_channel_id(inst); + if (channel_id && channel_dependency_group.count(*channel_id)) { + auto it = channel_dependency_group.find(*channel_id); + for (HloInstruction* cinst : it->second) { + dfs_stack.emplace_back(cinst); } - break; + } else { + dfs_stack.emplace_back(inst); } - default: - break; + }; + + const auto add_predecessors = [&](HloInstruction* inst) { + // Add the operands to the stack in reverse order so the first operand is + // processed first. 
This will produce a more natural ordering and a nicer + // result for things like HLO stringification. + const auto& operands = inst->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + add_dfs_stack(operands[i]); + } + + for (HloInstruction* op : inst->control_predecessors()) { + add_dfs_stack(op); + } + }; + + // If the current instruction is a channel instruction, add the dependencies + // from all associated instructions of the channel. + auto channel_id = get_channel_id(current); + if (channel_id && channel_dependency_group.count(*channel_id)) { + auto it = channel_dependency_group.find(*channel_id); + for (HloInstruction* cinst : it->second) { + add_predecessors(cinst); + } + } else { + add_predecessors(current); } } } -HloComputation::ChannelDependencyMap +HloComputation::ChannelDependencyGroup HloComputation::ComputeChannelDependencies() const { - ChannelDependencyMap channel_dependency_map; + ChannelDependencyGroup channel_dependency_group; for (const auto& instruction : instructions_) { switch (instruction->opcode()) { - case HloOpcode::kSend: { - channel_dependency_map[instruction->channel_id()].push_back( + case HloOpcode::kSend: + case HloOpcode::kRecvDone: + channel_dependency_group[instruction->channel_id()].push_back( instruction.get()); break; - } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { auto all_reduce_id = instruction->all_reduce_id(); if (all_reduce_id) { - auto& dependencies = channel_dependency_map[all_reduce_id.value()]; - absl::c_copy(instruction->operands(), - std::back_inserter(dependencies)); - absl::c_copy(instruction->control_predecessors(), - std::back_inserter(dependencies)); + channel_dependency_group[all_reduce_id.value()].push_back( + instruction.get()); } break; } @@ -387,15 +396,16 @@ HloComputation::ComputeChannelDependencies() const { break; } } - return channel_dependency_map; + return channel_dependency_group; } std::vector HloComputation::MakeInstructionPostOrder() const { - auto 
channel_dependency_map = ComputeChannelDependencies(); + auto channel_dependency_group = ComputeChannelDependencies(); std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; absl::flat_hash_map visited; + visited.reserve(instruction_count()); for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -403,7 +413,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(channel_dependency_map, &post_order, + ComputeInstructionPostOrder(channel_dependency_group, &post_order, instruction.get(), &visited); } } @@ -530,11 +540,10 @@ HloComputation::CreateFromProto( HloInstruction* root = instruction_map.at(proto.root_id()); // Sort the instructions in the proto id's order. - std::sort(instructions.begin(), instructions.end(), - [&](const std::unique_ptr& a, - const std::unique_ptr& b) { - return to_proto_id[a.get()] < to_proto_id[b.get()]; - }); + absl::c_sort(instructions, [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); TF_RETURN_IF_ERROR([&]() -> Status { std::vector parameters_seen(parameter_count); @@ -599,7 +608,7 @@ StatusOr HloComputation::DeepCopyHelper( const std::function< HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, HloComputation* computation)>& copy_leaf) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { std::vector elements; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); i++) { @@ -616,14 +625,14 @@ StatusOr HloComputation::DeepCopyHelper( } return AddInstruction(HloInstruction::CreateTuple(elements)); } - if (ShapeUtil::IsToken(instruction->shape())) { + if (instruction->shape().IsToken()) { // Tokens have no 
on-device representation and cannot be copied. Pass // through transparently. return instruction; } // Array shape. - TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape())); + TF_RET_CHECK(instruction->shape().IsArray()); return copy_leaf(instruction, *index, this); } @@ -693,25 +702,37 @@ bool HloComputation::operator==(const HloComputation& other) const { if (this == &other) { return true; } - std::set> visited; - std::function eq = - [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { - // If are visited but not identical, the recursion should have - // been aborted. So, if are visited at this point, they must be - // identical. - if (visited.count(std::make_pair(a, b)) > 0) { - return true; - } - visited.emplace(a, b); - return a->Identical( - *b, eq, [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }); - }; - return eq(root_instruction(), other.root_instruction()); -} + absl::flat_hash_set> + visited; + std::vector> worklist; + + worklist.push_back({root_instruction(), other.root_instruction()}); -uint64 HloComputation::Hash() const { return root_instruction()->Hash(); } + while (!worklist.empty()) { + auto pair = worklist.back(); + worklist.pop_back(); + + if (visited.contains(pair)) { + continue; + } + visited.emplace(pair); + // TODO(b/123082518): Avoid recursively invoking == because it may + // cause a stack overflow with deeply nested subcomputations.
+ bool identical_ignoring_operands = pair.first->Identical( + *pair.second, + [](const HloInstruction*, const HloInstruction*) { return true; }, + [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }); + if (!identical_ignoring_operands) { + return false; + } + for (size_t i = 0; i < pair.first->operands().size(); ++i) { + worklist.push_back({pair.first->operand(i), pair.second->operand(i)}); + } + } + return true; +} Status HloComputation::ReplaceWithNewInstruction( HloInstruction* old_instruction, @@ -797,20 +818,19 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + absl::Span order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { - TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) - << root->ToString(); + TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString(); } TF_RET_CHECK(order.size() == instruction_count()); - std::unordered_set visited; + absl::flat_hash_set visited; for (const HloInstruction* instruction : order) { VLOG(3) << "Visiting ordered: " << instruction->ToString(); - TF_RET_CHECK(instruction_iterators_.count(instruction) == 1) + TF_RET_CHECK(instruction_iterators_.contains(instruction)) << "Instruction " << instruction->name() << " is not in computation " << name(); - TF_RET_CHECK(visited.count(instruction) == 0) + TF_RET_CHECK(!visited.contains(instruction)) << "Instruction " << instruction->name() << " appears more than once in order"; HloInstruction* mutable_instruction = @@ -827,9 +847,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. 
template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, absl::Span) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, absl::Span) const; Status HloComputation::Accept( const std::function& visitor_func) { @@ -846,29 +866,31 @@ Status HloComputation::Accept( std::unique_ptr HloComputation::Clone( const string& suffix, HloCloneContext* context) { return CloneWithReplacements( - /*replacements=*/std::unordered_map>(), - context, suffix); + /*replacements=*/absl::flat_hash_map>(), + /*extra_parameters=*/{}, context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r1, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r1, std::pair> r2, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( @@ -876,17 +898,19 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r2, std::pair> r3, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); replacements.emplace(std::move(r3)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return 
CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( - std::unordered_map> + absl::flat_hash_map> replacements, + absl::Span extra_parameters, HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { @@ -952,6 +976,12 @@ std::unique_ptr HloComputation::CloneWithReplacements( } std::vector> instructions; + // First add the extra parameters to 'instructions'. + for (const auto& instr : extra_parameters) { + CHECK_EQ(instr->opcode(), HloOpcode::kParameter) + << "Only parameter instructions are allowed in 'extra_parameters'"; + instructions.emplace_back(instr->Clone()); + } for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c584e4c7ca5770533f28352b0df9dadd9dbe1860..212dfa15a13185f1050103739fad8b560270d401 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -264,12 +263,6 @@ class HloComputation { // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const; - // Generates a hash value of an HLO computation. Hash considers - // information on opcode, shape, operands, and typically a root instruction. - // This function returns the same hash value for equivalent HLO computations, - // with respect to HloInstruction::Identical() method. - uint64 Hash() const; - // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. 
Status ReplaceWithNewInstruction( @@ -307,7 +300,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + absl::Span order) const; // Same as Accept() above, but the visitor is given as a function. Status Accept(const std::function& visitor_func); @@ -329,11 +322,16 @@ class HloComputation { // that's not already in the computation, it's cloned and added to the new // computation. // + // 'extra_parameters' allows specifying additional parameters that should be + // added to the computation. + // // All relevant instructions are cloned, *including* unique_ptr in the // `replacements` map. std::unique_ptr CloneWithReplacements( - std::unordered_map> + absl::flat_hash_map> replacements, + absl::Span extra_parameters = {}, HloCloneContext* context = nullptr, const string& suffix = "clone"); // Convenience overloads for CloneWithReplacements. You want to do @@ -371,13 +369,13 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); - // Returns a map from channel-id to directed dependencies of the channel - // instructions. For send&recv pairs it means the send instruction and for - // cross-replica-sum the union of the dependencies for all participating - // instructions. - using ChannelDependencyMap = + // Returns a map from channel-id to the group of instructions associated with + // the channel. These instructions will be considered as a single node for + // dependency purposes. Send and RecvDone are in the group, and AllReduces + // with the same channel id are in the group. + using ChannelDependencyGroup = absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; + ChannelDependencyGroup ComputeChannelDependencies() const; // Returns true if this computation has a side effect.
A computation has a // side effect if it contains one or more instructions with a side effect. @@ -393,6 +391,10 @@ class HloComputation { fusion_instruction_ = fusion_instruction; } + // Clear the unique ID of the computation so that it can be re-assigned, such + // as for the purpose of compacting the unique IDs. + void ClearUniqueIdInternal() { unique_id_ = -1; } + // The id of this computation should be unique within the module. void SetUniqueId(int64 id) { CHECK_EQ(unique_id_, -1); @@ -436,7 +438,7 @@ class HloComputation { enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyMap& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_map, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 0361c87428f6e4c031d95492a5bc782ad388e5b5..fe37ca6b3963430c765f27aede4f506366fc5d97 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -15,12 +15,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include #include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -33,6 +39,7 @@ namespace xla { namespace { namespace m = match; +namespace op = xla::testing::opcode_matchers; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -226,7 +233,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { : computation_(computation) {} Status DefaultAction(HloInstruction* hlo_instruction) override { - EXPECT_EQ(0, visited_set_.count(hlo_instruction)); + EXPECT_FALSE(visited_set_.contains(hlo_instruction)); visited_set_.insert(hlo_instruction); last_visited_ = hlo_instruction; return Status::OK(); @@ -239,7 +246,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { } HloComputation* computation_; - std::set visited_set_; + absl::flat_hash_set visited_set_; int64 finish_visit_calls_ = 0; HloInstruction* last_visited_ = nullptr; }; @@ -491,6 +498,41 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } +TEST_F(HloComputationTest, CloneWithReplacements) { + auto builder = HloComputation::Builder(TestName()); + Shape r0s64 = ShapeUtil::MakeShape(S64, {}); + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + Shape r0u32 = ShapeUtil::MakeShape(U32, {}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "p.0.lhs")); + auto param1 
= builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1")); + auto lt = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1)); + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/lt)); + absl::flat_hash_map> + replacements; + replacements.emplace(param2, + HloInstruction::CreateParameter(2, r0s32, "p.1")); + auto param3 = HloInstruction::CreateParameter(3, r0u32, "p.2"); + std::vector extra_parameters{param3.get()}; + auto clone = computation->CloneWithReplacements(std::move(replacements), + extra_parameters); + ASSERT_EQ(clone->num_parameters(), 4); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(0)->shape(), r0f32_)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(1)->shape(), r0f32_)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(2)->shape(), r0s32)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(3)->shape(), r0u32)); +} + TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); @@ -606,5 +648,57 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } +std::unique_ptr MakeAddNComputation(int n) { + auto builder = HloComputation::Builder("add_n"); + auto result = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "x_value")); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + for (int i = 0; i < n; ++i) { + result = builder.AddInstruction(HloInstruction::CreateBinary( + one->shape(), HloOpcode::kAdd, result, one)); + } + return builder.Build(); +} + +TEST_F(HloComputationTest, DeepEquality) 
{ + auto computation_a = MakeAddNComputation(200000); + auto computation_b = MakeAddNComputation(200000); + EXPECT_TRUE(*computation_a == *computation_b); + + auto computation_c = MakeAddNComputation(199999); + EXPECT_FALSE(*computation_a == *computation_c); + EXPECT_FALSE(*computation_c == *computation_b); +} + +// Tests that cross-module AllReduce instructions are ordered after all their +// predecessors and before all their successors. +TEST_F(HloComputationTest, InstructionPostOrderWithAllReduce) { + const char* const hlo_string = R"( +HloModule Module + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param = f32[128] parameter(0), sharding={maximal device=0} + crs0 = f32[128] all-reduce(param), + replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add, + sharding={maximal device=0} + crs1 = f32[128] all-reduce(param), + replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add, + sharding={maximal device=1} + add = f32[128] add(crs0, crs0), sharding={maximal device=0} + ROOT t = (f32[128], f32[128]) tuple(add, crs1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(), + ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(), + op::Add(), op::Tuple())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 5e37883d3d8d5067bab873ac6b5f732e7360c5fa..e7ed858e8c5af83d08863d64a0aba162c75ed5cb 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -35,6 +35,34 @@ limitations under the License. namespace xla { +// Checks whether instr is or transitively contains an instruction that we +// shouldn't fold.
+// +// Specifically, we don't fold kRng or kAfterAll instructions: +// +// - kRng is already marked as side-effecting and so is skipped elsewhere, but +// we check for it here. Even if kRng weren't side-effecting and took an +// explicit seed, we *still* wouldn't want to constant-fold it, because the +// evaluator's handling of rng is not guaranteed to be identical to any +// particular backend's rng. +// +// - kAfterAll needs to be skipped because a kAfterAll op with no args can +// currently materialize a token "out of thin air". TODO(b/110532604): +// Remove this check once AfterAll requires at least one operand, in which +// case constant folding will be impossible. +static bool IsOrContainsIllegalInstr(const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kAfterAll || + instr->opcode() == HloOpcode::kRng) { + return true; + } + for (const HloComputation* c : instr->called_computations()) { + if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) { + return true; + } + } + return false; +} + StatusOr HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may @@ -52,25 +80,24 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Tuple, AfterAll operation. - // Tuple constants are not directly supported by any backends, hence - // folding Tuple is not useful and would in fact be expanded back into - // kTuple by Algebraic Simplifier. - // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one - // operand in which case constant folding will be impossible and this - // special case is not necessary.
- if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kAfterAll) { - continue; - } // Skip instructions with non-constant operands. if (!hlo_query::AllOperandsAreConstants(*instruction)) { continue; } + // Don't fold Constant, Parameter, and Tuple instructions. Tuple + // constants are not directly supported by any backends, hence folding + // Tuple is not useful and would in fact be expanded back into kTuple by + // Algebraic Simplifier. + // + // (We do allow folding subcomputations that contain these instructions.) + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kTuple) { + continue; + } + // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. @@ -79,12 +106,23 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Check for instructions that we can't fold even if they appear inside of + // a subcomputation (e.g. a kCall). + if (IsOrContainsIllegalInstr(instruction)) { + continue; + } + + // Don't constant-fold side-effecting instructions or instructions which + // contain side-effecting instructions. + if (instruction->HasSideEffect()) { + continue; + } + // Don't constant fold unless it's a net positive or the output is small. 
- if (ShapeUtil::IsArray(instruction->shape())) { + if (instruction->shape().IsArray()) { int64 elements_in_removed_operands = 0; for (HloInstruction* operand : instruction->operands()) { - if (operand->user_count() == 1 && - ShapeUtil::IsArray(operand->shape())) { + if (operand->user_count() == 1 && operand->shape().IsArray()) { elements_in_removed_operands += ShapeUtil::ElementsIn(operand->shape()); } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 4f81dc94e577a63c09ae4019e5e8158252c712ce..4bdc980c9ac4fb79cde0242f407aea7057474b27 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -252,7 +252,7 @@ const char* const kConstantFoldLargePad = R"( HloModule ConstantFoldLargePad ENTRY r { - a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + a = f32[1,1,1] constant({{{7}}}) b = f32[] constant(42) ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 })"; @@ -268,5 +268,51 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { GmockMatch(m::Pad(m::Constant(), m::Constant()))); } +TEST_F(HloConstantFoldingTest, DontFoldSubcomputationContainingAfterAll) { + const char* const kModuleStr = R"( + HloModule test + + Fn { + tok = token[] after-all() + ROOT root = f32[10] iota(), iota_dimension=0 + } + + ENTRY entry { + ROOT call = f32[10] call(), to_apply=Fn + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(HloConstantFoldingTest, + DontFoldSubcomputationTransitivelyContainingRng) { + const char* const kModuleStr = R"( + HloModule test + + InnerFn { + c0 = f32[] constant(0) + c1 = f32[] constant(1) + ROOT rng = f32[10] rng(c0, c1), distribution=rng_uniform + } + + Fn { + ROOT 
fusion = f32[10] fusion(), kind=kLoop, calls=InnerFn + } + + ENTRY entry { + ROOT call = f32[10] call(), to_apply=Fn + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_FALSE(result); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index df7d3826dbad1f264a5dc53312c062900155b0f6..6d9e01e3a77b1cdb5d9bad69bb2754e3ce3380c0 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -91,9 +91,10 @@ Status HloCostAnalysis::HandleElementwiseOp( auto opcode = hlo_instruction->opcode(); // We treat transcendental operations separately since one transcendental // operation can correspond to several floating point ops. - if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower || - opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin || - opcode == HloOpcode::kCos) { + if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog || + opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt || + opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh || + opcode == HloOpcode::kSin || opcode == HloOpcode::kCos) { current_properties_[kTranscendentalsKey] = computation_count; } else { // Note: transcendental operations are considered a separate category from @@ -237,24 +238,17 @@ Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); + const Shape& dot_shape = dot->shape(); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). 
- int64 reduction_width = - lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); - // First divide by reduction width before multiplying by rhs elements to avoid - // overflow. - int64 fma_count; - if (reduction_width == 0) { - fma_count = 0; - } else { - fma_count = (ShapeUtil::ElementsIn(lhs_shape) / reduction_width) * - ShapeUtil::ElementsIn(rhs_shape); + int64 reduction_width = 1; + for (auto dim : dnums.lhs_contracting_dimensions()) { + reduction_width *= lhs_shape.dimensions(dim); } - - // We count an FMA operation as 2 floating point operations. - current_properties_[kFlopsKey] = kFmaFlops * fma_count; + // Each output element requires reduction_width FMA operations. + current_properties_[kFlopsKey] = + kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width; return Status::OK(); } @@ -292,7 +286,7 @@ Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { // does not need to be multiplied by the number of input tensors - that's // already "priced in" by the sub-computation doing more work. auto arg = reduce->operand(0); - auto output_shape = ShapeUtil::IsArray(reduce->shape()) + auto output_shape = reduce->shape().IsArray() ? reduce->shape() : reduce->shape().tuple_shapes(0); int64 reduction_count = @@ -531,7 +525,8 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { } const int64 fma_count = (input_feature / convolution->feature_group_count()) * - output_feature * batch * + output_feature * + (batch / convolution->batch_group_count()) * Product(valid_position_counts); current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); @@ -539,7 +534,7 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { auto real_shape = - ShapeUtil::IsTuple(fft->operand(0)->shape()) + fft->operand(0)->shape().IsTuple() ?
ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0) : fft->operand(0)->shape(); constexpr int kFmaPerComplexMul = 4; @@ -552,7 +547,22 @@ Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { return Status::OK(); } -Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { +Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) { + float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f; + bytes_accessed += GetShapeSize(hlo->operand(1)->shape()); + current_properties_[kBytesAccessedKey] = bytes_accessed; + + const Shape& a_shape = hlo->operand(0)->shape(); + const Shape& b_shape = hlo->operand(1)->shape(); + // Estimate as batch * mn^2 / 2 flops. + int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1); + elems *= ShapeUtil::ElementsIn(b_shape); + // Each element of b is charged one FMA per entry along a's last dimension. + current_properties_[kFlopsKey] = kFmaFlops * elems; + return Status::OK(); +} + +Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. // @@ -561,7 +571,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { double flops = 0.0; ShapeUtil::ForEachSubshape(crs->shape(), [&](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { flops += ShapeUtil::ElementsIn(subshape); } }); @@ -577,6 +587,10 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } +Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution.
For now, assume diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 33983119c9b00a248c0e8dcc5815c6367192dca3..96357dec68e390251c43c2c3fc6f5a5612063fbd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,9 +71,11 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; - Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleTriangularSolve(const HloInstruction* hlo) override; + Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; + Status HandleReplicaId(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleRng(const HloInstruction* random) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index ff32faf298dd1f04c5b769f2a88f76a7a1e18ae7..4d42770ba784ba15fae9518b40a75d8a2f038e66 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/compiler/xla/statusor.h" @@ -157,6 +158,87 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30)); } +TEST_F(HloCostAnalysisTest, DotGeneral) { + XlaBuilder builder("matrix_multiply"); + auto lhs = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs"); + auto rhs = + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs"); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(0); + dnums.add_rhs_contracting_dimensions(1); + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (7500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 10 * 30)); +} + +TEST_F(HloCostAnalysisTest, DotGeneral2) { + XlaBuilder builder("matrix_multiply"); + auto lhs = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs"); + auto rhs = + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs"); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(2); + dnums.add_rhs_contracting_dimensions(0); + dnums.add_rhs_batch_dimensions(1); + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. 
+ auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (7500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 5 * 10 * 30)); +} + +TEST_F(HloCostAnalysisTest, DotGeneral3) { + XlaBuilder builder("matrix_multiply"); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + DotDimensionNumbers dnums; + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (7500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 + 5 * 30 + 5 * 5 * 10 * 30)); +} + TEST_F(HloCostAnalysisTest, Map) { XlaBuilder builder("map"); auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in"); @@ -529,7 +611,8 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) { // Test the analysis on a slice. XlaBuilder builder("dynamic-slice"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); - DynamicSlice(x, ConstantR1(&builder, {1}), {1}); + DynamicSlice(x, absl::Span({ConstantR0(&builder, 1)}), + {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. 
@@ -545,7 +628,7 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { XlaBuilder builder("dynamic-update-slice"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); DynamicUpdateSlice(x, ConstantR1(&builder, {1.0}), - ConstantR1(&builder, {1})); + absl::Span({ConstantR0(&builder, 1)})); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index b2005d3c210d4ae7e3702cb9624c3ad98056984c..b5d9e8e7f1a703d5d914a12d5226d53821071be6 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -17,9 +17,15 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -69,11 +75,11 @@ StatusOr MakeConvolveHlo( CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), feature_group_count, + lhs->shape(), rhs->shape(), feature_group_count, 1, window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers, - precision_config)); + convolve_shape, lhs, rhs, 
feature_group_count, 1, window, + dimension_numbers, precision_config)); } StatusOr MakeTransposeHlo(HloInstruction* operand, @@ -105,12 +111,26 @@ StatusOr MakeDynamicSliceHlo( absl::Span slice_sizes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, start_indices->parent()); + int64 rank = start_indices->shape().dimensions(0); + std::vector scalar_start_indices; + for (int i = 0; i < rank; ++i) { + // TODO(b/118437727): Update callers to provide scalars directly. + auto slice = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}), + start_indices, {i}, {i + 1}, {1})); + scalar_start_indices.push_back( + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {}), + slice))); + } + std::vector scalar_start_indices_shapes( + rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); TF_ASSIGN_OR_RETURN( Shape dynamic_slice_shape, ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), slice_sizes)); + operand->shape(), scalar_start_indices_shapes, slice_sizes)); return computation->AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, operand, start_indices, slice_sizes)); + dynamic_slice_shape, operand, scalar_start_indices, slice_sizes)); } StatusOr MakeDynamicUpdateSliceHlo( @@ -119,17 +139,31 @@ StatusOr MakeDynamicUpdateSliceHlo( HloComputation* computation = operand->parent(); CHECK_EQ(computation, update->parent()); CHECK_EQ(computation, start_indices->parent()); + int64 rank = start_indices->shape().dimensions(0); + std::vector scalar_start_indices; + for (int i = 0; i < rank; ++i) { + // TODO(b/118437727): Update callers to provide scalars directly. 
+ auto slice = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}), + start_indices, {i}, {i + 1}, {1})); + scalar_start_indices.push_back( + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {}), + slice))); + } + std::vector scalar_start_indices_shapes( + rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); TF_ASSIGN_OR_RETURN( Shape dynamic_update_slice_shape, ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); + operand->shape(), update->shape(), scalar_start_indices_shapes)); return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - dynamic_update_slice_shape, operand, update, start_indices)); + dynamic_update_slice_shape, operand, update, scalar_start_indices)); } -StatusOr MakeBroadcastHlo( - HloInstruction* operand, absl::Span broadcast_dimensions, - absl::Span result_shape_bounds) { +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + absl::Span result_shape_bounds) { HloComputation* computation = operand->parent(); Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_bounds); @@ -189,8 +223,7 @@ StatusOr MakeMapHlo(absl::Span operands, for (const HloInstruction* operand : operands) { CHECK_EQ(computation, operand->parent()); operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + max_operand_rank = std::max(max_operand_rank, operand->shape().rank()); } std::vector map_dims(max_operand_rank); std::iota(map_dims.begin(), map_dims.end(), 0); @@ -207,7 +240,7 @@ StatusOr MakeReduceHlo(HloInstruction* operand, HloOpcode binary_opcode, HloModule* module) { DCHECK_NE(nullptr, module); - std::vector all_dims(ShapeUtil::Rank(operand->shape())); + std::vector 
all_dims(operand->shape().rank()); std::iota(all_dims.begin(), all_dims.end(), 0); auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); @@ -240,6 +273,29 @@ StatusOr MakeSelectHlo(HloInstruction* pred, select_shape, HloOpcode::kSelect, pred, on_true, on_false)); } +StatusOr MakeSortHlo( + const Shape& sort_shape, absl::Span operands, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, + HloModule* module) { + CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; + HloComputation* compare_computation; + XlaBuilder b("Sort.Compare"); + std::vector operand_types(operands.size()); + for (int64 i = 0; i < operands.size(); ++i) { + operand_types[i] = operands[i]->shape().element_type(); + } + XlaComputation comparator = CreateScalarLtComputation(operand_types, &b); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(comparator.proto(), config)); + HloCloneContext context(module); + compare_computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + return builder->AddInstruction(HloInstruction::CreateSort( + sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -366,9 +422,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, return MakePadHlo(operand, zero, padding_config); } -StatusOr BroadcastZeros( - HloComputation* computation, PrimitiveType element_type, - absl::Span broadcast_dimensions) { +HloInstruction* BroadcastZeros(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, diff --git 
a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 8e5ddbbd503a501bd493aec43a2ccd4db883ef0c..17b7a2da6a9da994ea2d496b549eec79278b56b5 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -82,9 +82,9 @@ StatusOr MakeDynamicUpdateSliceHlo( // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeBroadcastHlo( - HloInstruction* operand, absl::Span broadcast_dimensions, - absl::Span result_shape_bounds); +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + absl::Span result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. @@ -123,6 +123,15 @@ StatusOr MakeSelectHlo(HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false); +// Creates a Sort HLO instruction and adds it to the computation containing the +// operands. All operands must be in the same computation. Also creates a +// default compare sub-computation which sorts the first operand into ascending +// order. 'is_stable' specifies whether the sorting should be stable. +StatusOr MakeSortHlo( + const Shape& sort_shape, absl::Span operands, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, + HloModule* module); + // Creates an R1 Constant HLO instruction of the given PrimitiveType with the // given values and adds it to the given computation. template @@ -198,9 +207,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, // Broadcasts a zero value of type `element_type` into a tensor with element // type `element_type` and dimension bounds `broadcast_dimensions`. The // broadcast instruction is emitted into `computation`. 
-StatusOr BroadcastZeros( - HloComputation* computation, PrimitiveType element_type, - absl::Span broadcast_dimensions); +HloInstruction* BroadcastZeros(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions); // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index aaa9ec60eb3c4e0159ed40b37d772e0973d306ec..6025e6a77941369f75ebaa98bdf0979669b3a03c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -56,9 +56,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({3, 4})})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({3, 4})})); CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } @@ -77,10 +77,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate( - *module, - {LiteralUtil::CreateR3( - {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, + {{-1, -2}, {-3, -4}, {-5, -6}}})})); CHECK_EQ(result_literal, LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); @@ -101,8 +100,7 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, - {LiteralUtil::CreateR1({9, 10})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({9, 10})})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 
10}})); } @@ -121,8 +119,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, - {LiteralUtil::CreateR1({9, 10})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({9, 10})})); CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } @@ -141,7 +138,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } @@ -160,8 +157,8 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } @@ -180,9 +177,9 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({3, 4})})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({3, 4})})); CHECK_EQ(result_literal, LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } @@ -194,15 +191,14 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { /*output_shape_dims=*/{2, 2}, &param, &entry_computation); - TF_ASSERT_OK_AND_ASSIGN( - HloInstruction * zeros, - BroadcastZeros(module->entry_computation(), S32, {2, 2})); + HloInstruction* zeros = + BroadcastZeros(module->entry_computation(), S32, {2, 2}); entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - 
evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } @@ -214,15 +210,14 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { /*output_shape_dims=*/{2, 2}, &param, &entry_computation); - TF_ASSERT_OK_AND_ASSIGN( - HloInstruction * zeros, - BroadcastZeros(module->entry_computation(), F32, {2, 2})); + HloInstruction* zeros = + BroadcastZeros(module->entry_computation(), F32, {2, 2}); entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR0(0.0f)})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0.0f)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index e602107cbe64320a8e8e740168cb294ec6be9667..849cac278ee379122ba1ff9fade3bf003969b8a7 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -33,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 3ed3d3c11c71dc534f193ba3ffb556b0eb0c80e4..3144a84805454488f417391f40ed6b9e9facc752 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -107,7 +107,7 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( return false; } } - if (!visited.count(user)) { + if (!visited.contains(user)) { stack.push_back(user); } } @@ -190,7 +190,7 @@ string HloDataflowAnalysis::ToString() const { for (const HloComputation* computation : module_.computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { GetInstructionValueSet(instruction) .ForEachElement([this, &instruction, &out]( const ShapeIndex& index, @@ -256,7 +256,7 @@ bool HloDataflowAnalysis::Phi( input_value_ids.push_back(value->id()); } } - std::sort(input_value_ids.begin(), input_value_ids.end()); + absl::c_sort(input_value_ids); input_value_ids.erase( std::unique(input_value_ids.begin(), input_value_ids.end()), input_value_ids.end()); @@ -271,8 +271,7 @@ bool HloDataflowAnalysis::Phi( if (current_value_defined_here) { VLOG(5) << "current_value_defined_here: " << current_value->ToString(); CHECK(current_value->is_phi()); - auto it = std::find(input_value_ids.begin(), input_value_ids.end(), - current_value->id()); + auto it = absl::c_find(input_value_ids, current_value->id()); if (it != input_value_ids.end()) { input_value_ids.erase(it); } @@ -921,8 
+920,7 @@ StatusOr> HloDataflowAnalysis::Run( for (auto& pair : dataflow_analysis->values_) { dataflow_analysis->values_vector_.push_back(&pair.second); } - std::sort(dataflow_analysis->values_vector_.begin(), - dataflow_analysis->values_vector_.end(), HloValue::IdLessThan); + absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan); TF_DCHECK_OK(dataflow_analysis->Verify()); @@ -937,9 +935,7 @@ Status HloDataflowAnalysis::Verify() const { for (const HloValue* value : values()) { for (const HloPosition& position : value->positions()) { const HloValueSet& value_set = GetValueSet(position); - TF_RET_CHECK(std::find(value_set.values().begin(), - value_set.values().end(), - value) != value_set.values().end()) + TF_RET_CHECK(absl::c_linear_search(value_set.values(), value)) << "Value set at position " << position << " does not contain value " << value->ToShortString(); } @@ -954,9 +950,7 @@ Status HloDataflowAnalysis::Verify() const { const HloValueSet& value_set = pair.second; const HloPosition position{instruction, index}; for (const HloValue* value : value_set.values()) { - TF_RET_CHECK(std::find(value->positions().begin(), - value->positions().end(), - position) != value->positions().end()) + TF_RET_CHECK(absl::c_linear_search(value->positions(), position)) << "Value set at position " << position << " unexpectedly contains value " << value->ToShortString(); } @@ -1041,11 +1035,10 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Check if one operand of kAdd fused root is kDot or kConvolution. 
auto* add = user->fused_expression_root(); auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); + absl::c_find_if(add->operands(), [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); if (add_operand_it == add->operands().end()) { return false; } @@ -1100,16 +1093,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // *) The root instruction of the called computation is element-wise on // 'operand'. const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + absl::c_find_if(uses, [user](const HloUse& use) { return use.instruction == user; }) != uses.end(); auto* callee_root = user->to_apply()->root_instruction(); const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); + absl::c_find_if(uses, [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f7a1f19a6f52befd58a405d0e406d7d0d37a8e57..768e3afb3b80698061b62c4aadef09c20e2f286c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -73,8 +74,8 @@ class HloDataflowAnalysisTest : public HloTestBase, bool InstructionsMayInterfere(const HloOrdering& ordering, const HloInstruction* a, const HloInstruction* b) { - EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); - EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + EXPECT_FALSE(a->shape().IsTuple()); + EXPECT_FALSE(b->shape().IsTuple()); return ordering.MayInterfere(analysis_->GetValueDefinedAt(a), analysis_->GetValueDefinedAt(b), *analysis_); } @@ -1882,8 +1883,8 @@ TEST_P(HloDataflowAnalysisTest, AddDependency) { HloModule AddDependency ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( @@ -1901,9 +1902,9 @@ ENTRY %AddDependency (p: f32[3]) -> f32[3] { EXPECT_FALSE(analysis->ValueIsDefinedAt(root)); } -INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, - HloDataflowAnalysisTest, - ::testing::Values(false, true)); +INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation, + HloDataflowAnalysisTest, + ::testing::Values(false, true)); class HloDataflowAnalysisTestBase : public HloTestBase { protected: @@ -1970,12 +1971,13 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2012,12 +2014,13 @@ TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2150,17 +2153,17 @@ TEST_F(CanShareOperandBufferWithUserTest, auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "param0")); - auto index = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0}))); - auto ds = builder.AddInstruction( - HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, param, {zero, zero}, {1, 2, 2})); - auto dus = builder.AddInstruction( - HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds, index)); + auto dus = 
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, param, ds, {zero, zero})); BuildModule(builder.Build()); auto fusion = computation_->CreateFusionInstruction( - {dus, ds, index}, HloInstruction::FusionKind::kLoop); + {dus, ds, zero}, HloInstruction::FusionKind::kLoop); RunAnalysis(); EXPECT_TRUE( @@ -2219,12 +2222,13 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2259,12 +2263,13 @@ TEST_F(CanShareOperandBufferWithUserTest, // Create a DynamicUpdateSlice instruction of tuple element 1. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape_bf16, convert1, update, starts)); + data_shape_bf16, convert1, update, + std::initializer_list({starts}))); auto convert2 = builder.AddInstruction( HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); @@ -2290,10 +2295,13 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { HloInstruction::CreateParameter(0, data_shape, "data")); auto update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto start0 = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "start0")); + auto start1 = builder.AddInstruction( + HloInstruction::CreateParameter(3, starts_shape, "start1")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); + data_shape, data, update, {start0, start1})); BuildModuleAndRunAnalysis(builder.Build()); @@ -2304,7 +2312,9 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { EXPECT_FALSE( dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); + dataflow_analysis_->CanShareOperandBufferWithUser(start0, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(start1, {}, dus, {})); } TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { @@ -2347,14 +2357,17 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { TEST_F(CanShareOperandBufferWithUserTest, 
SortCanShare) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto sort = - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false, + &builder, module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); EXPECT_TRUE( dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); @@ -2362,6 +2375,7 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); Shape values_shape = ShapeUtil::MakeShape(F32, {8}); @@ -2369,11 +2383,14 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { HloInstruction::CreateParameter(0, keys_shape, "keys")); auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); - auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, - {values})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); // The buffer for the keys can be shared with the first tuple entry. 
EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 7d35e251ca21951036336ff1a1eb4aabc87bc5ca..a5a11f09cf4f857b992e5ede3a9dbc5a937ce722 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -65,7 +66,7 @@ StatusOr HloDCE::Run(HloModule* module) { // Now DCE HloComputations. First, collect the computations that are // referenced by some remaining instruction. - std::unordered_set live_computations; + absl::flat_hash_set live_computations; if (HloComputation* entry_computation = module->entry_computation()) { live_computations.insert(entry_computation); } @@ -79,7 +80,7 @@ StatusOr HloDCE::Run(HloModule* module) { // Remove dead computations. for (auto* computation : module->MakeComputationPostOrder()) { - if (live_computations.count(computation) == 0) { + if (!live_computations.contains(computation)) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 1fa4259a3e42286cbc911907eea563e6ca6f8611..b5d72b386f89568cc3066b2e497be98428d1ed0c 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -43,9 +43,7 @@ class HloDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. 
bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - return std::find(computation.instructions().begin(), - computation.instructions().end(), - instruction) != computation.instructions().end(); + return absl::c_linear_search(computation.instructions(), instruction); } }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index c6d02f9f67bb599e496d20fc2acf2e627ed54438..7cdb7f6bdf26241cda4fabbb5ccaf6e6f7de39ce 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -230,10 +230,10 @@ HloDomainMap::MakeNonDomainInstructions( } } // sort instructions according to instructions_order - std::sort(instructions.begin(), instructions.end(), - [&instructions_order](HloInstruction* a, HloInstruction* b) { - return instructions_order.at(a) < instructions_order.at(b); - }); + absl::c_sort(instructions, + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); + }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index acdb42128e3d9a1fb912a466c9c2c3cbbe3d3f83..fd4fb0246d8d42ab7329c05dc23e386303cdce3c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -195,10 +195,10 @@ HloModule Module ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 - token = token[] after-all() - b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all() + b = (f32[4], u32[], token[]) send(a, token0), channel_id=1, sharding={maximal device=0} c = token[] send-done(b), channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) 
recv(token0), channel_id=2, sharding={maximal device=0} e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} f = f32[4] add(a, e_element) @@ -235,12 +235,12 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=-1} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + token0 = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=-1} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} c = f32[4] add(b_element, b_element), sharding={maximal device=-1} - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=-1} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } )"; @@ -259,12 +259,12 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=0} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=0} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} c = f32[4] add(b_element, b_element) - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=0} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} } )"; 
@@ -344,8 +344,8 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { HloModule Module ENTRY entry { - token = token[] after-all() - infeed = ((f32[4], f32[4]), token[]) infeed(token), + token0 = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token0), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, sharding={{maximal device=1}, {maximal device=0}} diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 72006e17e7e7ec09b62e88d05b695ec9f4c49647..7d6b86056af3fc2128fe1642bbfa0ca6f9ef1da0 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -68,7 +68,7 @@ Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type, std::vector new_tuple_subshapes; for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { Shape subshape = ShapeUtil::GetTupleElementShape(shape, i); - CHECK(!ShapeUtil::IsTuple(subshape)); + CHECK(!subshape.IsTuple()); if (subshape.element_type() == from_type) { subshape = ShapeUtil::ChangeElementType(subshape, to_type); } @@ -92,7 +92,7 @@ HloInstruction* ConvertTupleElements(HloInstruction* hlo, HloInstruction* element = computation->AddInstruction( HloInstruction::CreateGetTupleElement(ele_shape, hlo, i)); const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i); - CHECK(!ShapeUtil::IsTuple(ele_shape)); + CHECK(!ele_shape.IsTuple()); if (ele_shape.element_type() != to_ele_shape.element_type()) { element = computation->AddInstruction( HloInstruction::CreateConvert(to_ele_shape, element)); @@ -127,6 +127,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops where it does not make sense to convert them. 
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kBitcastConvert || opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) { continue; @@ -141,12 +142,11 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || - opcode == HloOpcode::kCrossReplicaSum || - opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || - opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || - opcode == HloOpcode::kScatter || + opcode == HloOpcode::kAllReduce || opcode == HloOpcode::kFusion || + opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || + opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || - opcode == HloOpcode::kConditional) { + opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) { continue; } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); @@ -191,7 +191,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); - } else if (ShapeUtil::IsTuple(hlo->shape())) { + } else if (hlo->shape().IsTuple()) { Shape old_shape = hlo->shape(); Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index c170e36c73ad2bef830e528de3ec72d38683d888..4171f738620dbf545e5883b8c26169fae4b93643 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -28,15 +28,7 @@ using ::testing::Eq; using ::testing::Not; using ::testing::ResultOf; -class HloElementTypeConverterTest : public HloTestBase { - public: - std::unique_ptr CreateModuleFromHloString( - const string& hlo_string) { - return HloRunner::CreateModuleFromString(hlo_string, - GetDebugOptionsForTest()) - .ValueOrDie(); - } -}; +using HloElementTypeConverterTest = HloTestBase; TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { const string& hlo_string = R"( @@ -47,7 +39,7 @@ TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { custom_call_target="foo" } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_FALSE(converted); @@ -57,13 +49,13 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - token = token[] after-all() - infeed = (bf16[4]{0}, token[]) infeed(token) + token0 = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token0) ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) + outfeed = token[] outfeed(infeed.data, token0) } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_FALSE(converted); @@ -73,17 +65,16 @@ TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) { const string& hlo_string = R"( HloModule NestedTuples ENTRY NestedTuples.v5 { - constant.4 = bf16[] constant(42) constant.2 = f32[2]{0} constant({1, 2}) - constant.3 = bf16[] constant(42) - add 
= bf16[] add(constant.2, constant.3) - tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add) + constant.3 = bf16[2]{0} constant({42, 42}) + add = bf16[2]{0} add(constant.2, constant.3) + tuple = (f32[2]{0}, bf16[2]{0}) tuple(constant.2, add) constant.5 = bf16[2]{0} constant({22, 44}) - ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5) + ROOT tuple.1 = ((f32[2]{0}, bf16[2]{0}), bf16[2]{0}) tuple(tuple, constant.5) } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -96,13 +87,13 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { const string& hlo_string = R"( HloModule BatchNormGrad ENTRY BatchNormGrad.v6 { - constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.4 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } } }) constant.5 = bf16[2]{0} constant({1, 1}) constant.6 = bf16[2]{0} constant({0, 0}) constant.7 = bf16[2]{0} constant({1, 1}) - constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.8 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/ {5}, {6} }, { /*i1=1*/ {7}, {8} } } }) ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0}) @@ -111,7 +102,7 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -135,7 +126,7 @@ ENTRY main { ROOT rng = 
bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -161,7 +152,7 @@ ENTRY main { ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); @@ -185,5 +176,19 @@ ENTRY main { EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0)); } +TEST_F(HloElementTypeConverterTest, BitcastConvertIsUnmodified) { + const string& hlo_string = R"( + HloModule test + + ENTRY test { + p = bf16[] parameter(0) + ROOT c = u16[] bitcast-convert(p) + })"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + HloElementTypeConverter converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, RunHloPass(&converter, module.get())); + EXPECT_FALSE(converted); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3a7652a8dc856b23c8988c4676916c8199e78860..4d6487700b24cfd3b89aece58e5ad6d7bb43a800 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #include +#include #include #include -#include #include #include "absl/algorithm/container.h" @@ -29,10 +29,11 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -135,8 +136,44 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, return std::move(result); } +template <> +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](complex128 lhs_el, complex128 rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](complex128 lhs_el, complex128 rhs_el) { + return lhs_el != rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + + return std::move(result); +} + } // namespace +// Note that unsupported types by the typed visitor does not necessarily imply +// the non-typed HloEvaluator (parent evaluator) would not support them either +// in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent +// type-agnostic evaluator will be able to accept Tuple primitive type, whereas +// HloEvaluatorTypedVisitor cannot. 
HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = @@ -144,22 +181,14 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) typed_visitors_[U8] = absl::make_unique>(this); typed_visitors_[U16] = - absl::make_unique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "U16."); - }); + absl::make_unique>(this); typed_visitors_[U32] = absl::make_unique>(this); typed_visitors_[U64] = absl::make_unique>(this); typed_visitors_[S8] = absl::make_unique>(this); typed_visitors_[S16] = - absl::make_unique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "S16."); - }); + absl::make_unique>(this); typed_visitors_[S32] = absl::make_unique>(this); typed_visitors_[S64] = @@ -172,6 +201,8 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) absl::make_unique>(this); typed_visitors_[C64] = absl::make_unique>(this); + typed_visitors_[C128] = + absl::make_unique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). 
To make evaluator work with BF16, we set all @@ -197,65 +228,30 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) }); } -template -StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals) { - XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); - - evaluated_.clear(); - arg_literals_.clear(); - for (const auto& literal_ptr : arg_literals) { - arg_literals_.push_back(&*literal_ptr); - } - - TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); - - return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal_ptr : arg_literals) { - arg_literal_ptrs.push_back(&literal_ptr); - } - return Evaluate(module, arg_literal_ptrs); -} - -template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - absl::Span arg_literals) { + absl::Span arg_literals) { CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); - evaluated_.clear(); - arg_literals_.clear(); - for (const auto& literal_ptr : arg_literals) { - arg_literals_.push_back(&*literal_ptr); + if (arg_literals.size() != computation.num_parameters()) { + return InvalidArgument( + "Expected %d argument%s, but got %d.", computation.num_parameters(), + computation.num_parameters() == 1 ? 
"" : "s", arg_literals.size()); } - - TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - const HloComputation& computation, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal_ptr : arg_literals) { - arg_literal_ptrs.push_back(&literal_ptr); + for (int64 i = 0; i < arg_literals.size(); ++i) { + const auto& computation_shape = + computation.parameter_instruction(i)->shape(); + const auto& arg_shape = arg_literals[i]->shape(); + if (!ShapeUtil::Equal(computation_shape, arg_shape)) { + return InvalidArgument( + "Shape mismatch at parameter %d. Computation expected %s, but arg " + "was %s.", + i, ShapeUtil::HumanStringWithLayout(computation_shape), + ShapeUtil::HumanString(arg_shape)); + } } - return Evaluate(computation, arg_literal_ptrs); -} - -template -StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals) { - TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); evaluated_.clear(); arg_literals_.clear(); @@ -263,33 +259,20 @@ StatusOr HloEvaluator::Evaluate( arg_literals_.push_back(&*literal_ptr); } - // Evaluate operands of Parameter type against the input literals which - // caches the evaluated literal results. - for (const auto operand : instruction->operands()) { - if (operand->opcode() == HloOpcode::kParameter) { - const Literal* input_literal = arg_literals_[operand->parameter_number()]; - VLOG(2) << "Parameter operand evaluated to: " - << input_literal->ToString(); - TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - - evaluated_[operand] = input_literal->Clone(); - } + // Re-seed RNG, either from the configuration's seed or a monotonic + // per-evaluator seed (which prevents two evaluators from returning the same + // random sequence). 
+ if (computation.parent()->config().seed()) { + seed_ = computation.parent()->config().seed(); + } else { + // Start global_seed at a (true) random value. + static std::atomic global_seed{std::random_device()()}; + seed_ = global_seed.fetch_add(1); } + engine_.seed(seed_); - TF_RETURN_IF_ERROR(Preprocess(instruction)); - TF_RETURN_IF_ERROR(instruction->Visit(this)); - TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal : arg_literals) { - arg_literal_ptrs.push_back(&literal); - } - return Evaluate(instruction, arg_literal_ptrs); + TF_RETURN_IF_ERROR(computation.Accept(this)); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); } StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { @@ -407,16 +390,45 @@ Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status HloEvaluator::HandleGetDimensionSize( + HloInstruction* get_dimension_size) { + HloInstruction* operand = get_dimension_size->mutable_operand(0); + int64 dim = get_dimension_size->dimension(); + if (dynamic_dimension_inference_ == nullptr) { + return InvalidArgument( + "Evaluator cannot evaluate get_dimension_size without " + "set_dynamic_dimension_inference."); + } + HloInstruction* dynamic_size = + dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + evaluated_[get_dimension_size] = + GetEvaluatedLiteralFor(dynamic_size).Clone(); + return Status::OK(); + } + + const Shape& shape = get_dimension_size->operand(0)->shape(); + Literal output(ShapeUtil::MakeShape(U32, {})); + output.PopulateWithValue( + static_cast(shape.dimensions(get_dimension_size->dimension()))); + evaluated_[get_dimension_size] = std::move(output); + return Status::OK(); +} + Status 
HloEvaluator::HandleParameter(HloInstruction* parameter) { + // Nothing to do other than sanity checks. Parameters' values are stored in + // arg_literals_. CHECK_LT(parameter->parameter_number(), arg_literals_.size()); + +#ifndef NDEBUG const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())) << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape()) << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); +#endif - evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -441,8 +453,8 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); - CHECK(ShapeUtil::IsArray(reference_shape)); - const int64 rank = ShapeUtil::Rank(reference_shape); + CHECK(reference_shape.IsArray()); + const int64 rank = reference_shape.rank(); const int64 concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); CHECK_LT(concat_dim, rank); @@ -452,7 +464,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (int64 i = 1; i < operands.size(); ++i) { const Shape& operand_shape = operands[i]->shape(); - CHECK(ShapeUtil::IsArray(operand_shape)); + CHECK(operand_shape.IsArray()); // Accumulate the concat dimension from all tensors taking part to the // operation. 
concat_dimensions[concat_dim] += @@ -479,15 +491,52 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { auto operand = is_finite->operand(0); - if (!ShapeUtil::ElementIsFloating(operand->shape())) { - return InvalidArgument( - "expected element type in shape to be float for IsFinite op, got: %s", - PrimitiveType_Name(operand->shape().element_type())); - } + auto elem_ty = operand->shape().element_type(); + switch (elem_ty) { + case PRED: + case TUPLE: + case OPAQUE: + case TOKEN: + case S8: + case S16: + case S32: + case S64: + case U8: + case U16: + case U32: + case U64: + case C64: + case C128: + // Explicitly enumerate all types in this switch so that when we add a new + // type, we'll get a compile error here. + case PRIMITIVE_TYPE_INVALID: + case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_: + case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_: + return InvalidArgument( + "expected element type in shape to be floating point, but " + "got: %s", + PrimitiveType_Name(elem_ty)); - switch (operand->shape().element_type()) { - case F16: - return Unimplemented("unhandled primitive type: F16."); + case F16: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](Eigen::half elem_operand) { + return std::isfinite(static_cast(elem_operand)); + }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + case BF16: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](bfloat16 elem_operand) { + return std::isfinite(static_cast(elem_operand)); + }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } case F32: { auto result_or = ElementWiseUnaryOpImpl( is_finite, @@ -504,9 +553,6 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); break; } - default: - LOG(FATAL) << 
"HandleIsFinite: unknown/unhandled primitive type: " - << PrimitiveType_Name(operand->shape().element_type()); } return Status::OK(); @@ -529,6 +575,13 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); break; } + case C128: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](complex128 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } case F16: { auto result_or = ElementWiseUnaryOpImpl( real, [](Eigen::half elem_operand) { return elem_operand; }, @@ -559,11 +612,61 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { } Status HloEvaluator::HandleImag(HloInstruction* imag) { - auto result_or = ElementWiseUnaryOpImpl( - imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, - GetEvaluatedLiteralFor(imag->operand(0))); + auto operand = imag->operand(0); + switch (operand->shape().element_type()) { + case C64: { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + break; + } + case C128: { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex128 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); - TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); +} + +Status HloEvaluator::HandleComplex(HloInstruction* complex) { + const Literal& real = GetEvaluatedLiteralFor(complex->operand(0)); + const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1)); + 
TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape())); + + Literal result(complex->shape()); + switch (complex->shape().element_type()) { + case C64: { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return std::complex(real.Get(multi_index), + imag.Get(multi_index)); + })); + break; + } + case C128: { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return std::complex(real.Get(multi_index), + imag.Get(multi_index)); + })); + break; + } + default: + LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: " + << PrimitiveType_Name(complex->shape().element_type()); + } + + evaluated_[complex] = std::move(result); return Status::OK(); } @@ -600,8 +703,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case U16: - return Unimplemented("unhandled primitive type: U16."); + case U16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case U32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -617,8 +723,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case S16: - return Unimplemented("unhandled primitive type: S16."); + case S16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case S32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -629,8 +738,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case F16: - return Unimplemented("unhandled primitive type: F16."); + case F16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case BF16: { 
TF_ASSIGN_OR_RETURN(evaluated_[compare], Compare(compare->shape(), opcode, @@ -651,6 +763,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; + case C128: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; default: LOG(FATAL) << "HandleCompare: unknown primitive type: " << PrimitiveType_Name(lhs->shape().element_type()); @@ -1032,11 +1149,9 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand.shape())) + TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank()) << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand.shape()); + << " and rank of operand_to_broadcast is: " << operand.shape().rank(); // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. 
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { @@ -1109,9 +1224,10 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - Literal result = - embedded_evaluator.Evaluate(*computation, arg_literals) - .ConsumeValueOrDie(); + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); evaluated_[call] = std::move(result); return Status::OK(); @@ -1127,7 +1243,9 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { fusion->fused_instructions_computation()->Clone( /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { - LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + if (!LayoutUtil::HasLayout(instruction->shape())) { + LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + } } auto readded_computation = empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation)); @@ -1141,9 +1259,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); Literal result = - embedded_evaluator - .Evaluate(*readded_computation, arg_literals) + embedded_evaluator.Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); evaluated_[fusion] = std::move(result); @@ -1161,16 +1280,16 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); Literal result; if (pred.Get({})) { - result = embedded_evaluator - .Evaluate(*true_computation, - {&true_computation_arg}) - .ConsumeValueOrDie(); + result = + embedded_evaluator.Evaluate(*true_computation, {&true_computation_arg}) + 
.ConsumeValueOrDie(); } else { result = embedded_evaluator - .Evaluate(*false_computation, - {&false_computation_arg}) + .Evaluate(*false_computation, {&false_computation_arg}) .ConsumeValueOrDie(); } @@ -1217,18 +1336,21 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); + cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_); HloEvaluator loop_body_evaluator(max_loop_iterations_); + loop_body_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", while_hlo->name(), max_loop_iterations_); } TF_ASSIGN_OR_RETURN(auto cond_val, - cond_evaluator.Evaluate(*cond_comp, {&lcv})); + cond_evaluator.Evaluate(*cond_comp, {&lcv})); keep_going = cond_val.GetFirstElement(); if (keep_going) { - TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {&lcv})); + TF_ASSIGN_OR_RETURN(auto body_val, + loop_body_evaluator.Evaluate(*body_comp, {&lcv})); VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); @@ -1239,173 +1361,216 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } -// Key-value sort is a special snowflake: it's templated on two different -// element types, one for the keys, and one for the values. Jump through some -// hoops to make this work. 
namespace { -template -StatusOr EvaluateSortInternal(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - auto rank = ShapeUtil::Rank(keys_literal.shape()); - TF_RET_CHECK( - ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) - << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort"; - // We need to sort an array of keys and an array of values, where the - // sorted order of the values is determined by the keys. The simplest(?) - // way to do this is to go to an array-of-pairs representation, sort the - // array using the keys, and then go back to pair-of-arrays. - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - - if (rank == 0) { - // Nothing to sort. - return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); +template +Literal ExtractLiteralFromIndexPositions(const Literal& from, + absl::Span indices, + bool extract_as_scalar) { + if (extract_as_scalar) { + return LiteralUtil::CreateR0(from.Get({indices[0]})); } + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. 
+ absl::InlinedVector values; + for (int64 index : indices) { + values.push_back(from.Get({index})); + } + return LiteralUtil::CreateR1(values); +} - Literal keys_result_literal(keys_literal.shape()); - Literal values_result_literal(values_literal.shape()); +StatusOr ExtractFromIndexPositions(const Literal& from, + absl::Span indices, + bool extract_as_scalar = false) { + if (extract_as_scalar) { + CHECK_EQ(indices.size(), 1); + } + PrimitiveType type = from.shape().element_type(); + switch (type) { + case PRED: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U8: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S8: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case BF16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F64: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U64: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S64: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + default: + return InvalidArgument("Unsupported type for Sort: %s", + PrimitiveType_Name(type)); + } +} +} // namespace + +Status HloEvaluator::HandleSort(HloInstruction* sort) { + TF_RET_CHECK(sort->operand_count() >= 1) + << 
"Expected at least 1 operand for sort"; + for (int64 i = 1; i < sort->operand_count(); ++i) { + TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(i)->shape())) + << "All Sort operands must have the same dimensions"; + } + + if (VLOG_IS_ON(3)) { + for (int64 i = 0; i < sort->operand_count(); ++i) { + VLOG(3) << "HandleSort operand " << i << " literal: " + << GetEvaluatedLiteralFor(sort->operand(i)).ToString(); + } + } + Shape key_shape = sort->operand(0)->shape(); + auto rank = key_shape.rank(); + std::vector result_literals; + result_literals.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + result_literals.emplace_back(sort->operand(i)->shape()); + } std::vector zero_base(rank, 0); std::vector increment(rank, 1); int64 sort_dim = sort->dimensions(0); - int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); + int64 sort_dim_elements = key_shape.dimensions(sort_dim); increment[sort_dim] = sort_dim_elements; + HloEvaluator embedded_evaluator(max_loop_iterations_); // Iterate through each dimension except 'sort_dim'. TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - keys_literal.shape(), zero_base, - AsInt64Slice(keys_literal.shape().dimensions()), increment, + key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment, [&](absl::Span indices) -> StatusOr { - // Extract a slice from the keys and values literals that correspond to + // Extract a slice from each operand literal that corresponds to // exactly the row in dimension 'sort_dim'. 
std::vector limit_indices(indices.begin(), indices.end()); - std::for_each(limit_indices.begin(), limit_indices.end(), - [](int64& index) { ++index; }); + absl::c_for_each(limit_indices, [](int64& index) { ++index; }); limit_indices[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto keys_to_sort, - keys_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& keys_data = keys_to_sort.data(); - TF_ASSIGN_OR_RETURN(auto values_to_sort, - values_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& values_data = values_to_sort.data(); - using kv_pair = std::pair; - std::vector key_value_vector; - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back( - std::make_pair(keys_data[i], values_data[i])); + std::vector literals_to_sort; + literals_to_sort.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(auto literal_to_sort, + GetEvaluatedLiteralFor(sort->operand(i)) + .Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + literals_to_sort.push_back(std::move(literal_to_sort)); + } + std::vector indices_to_sort(sort_dim_elements); + std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0); + Status compare_status = Status::OK(); + auto comparator = [sort, &compare_status, &embedded_evaluator, + &literals_to_sort](int64 a, int64 b) { + std::vector literals; + literals.reserve(2 * sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a}, + /*extract_as_scalar=*/true); + if (!lhs.ok()) { + compare_status = lhs.status(); + return false; + } + literals.push_back(std::move(lhs.ValueOrDie())); + auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b}, + /*extract_as_scalar=*/true); + if (!rhs.ok()) { + compare_status = rhs.status(); + return false; + } + 
literals.push_back(std::move(rhs.ValueOrDie())); + } + std::vector literal_ptrs; + absl::c_transform(literals, std::back_inserter(literal_ptrs), + [](const Literal& literal) { return &literal; }); + + auto computed_result = + embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs); + // Clear visit states so that we can use the evaluator again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + if (!computed_result.ok()) { + compare_status = computed_result.status(); + return false; + } + return computed_result.ValueOrDie().Get({}); + }; + if (Cast(sort)->is_stable()) { + std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(), + comparator); + } else { + std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator); } - std::stable_sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); - std::vector result_keys; - // We use a InlinedVector here because we need to convert it to an - // absl::Span later, and this would not work with std::vector. 
- absl::InlinedVector result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); + if (!compare_status.ok()) { + return compare_status; } - Literal sorted_keys(ShapeUtil::MakeShape( - keys_literal.shape().element_type(), {sort_dim_elements})); - sorted_keys.PopulateR1(absl::Span(result_keys)); - Literal sorted_values(ShapeUtil::MakeShape( - values_literal.shape().element_type(), {sort_dim_elements})); - sorted_values.PopulateR1(absl::Span(result_values)); std::vector slice_dimensions(rank, 1); slice_dimensions[sort_dim] = sort_dim_elements; std::vector start_indices(rank, 0); - TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, - sorted_keys.Reshape(slice_dimensions)); - TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( - sorted_keys_reshaped, start_indices, indices, slice_dimensions)); - TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, - sorted_values.Reshape(slice_dimensions)); - TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( - sorted_values_reshaped, start_indices, indices, slice_dimensions)); + for (int64 i = 0; i < sort->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN( + Literal sorted_literal, + ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort)); + TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped, + sorted_literal.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom( + sorted_literal_reshaped, start_indices, indices, + slice_dimensions)); + } return true; })); - Literal result_tuple; - result_tuple = - LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); - VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); - return std::move(result_tuple); -} - -template -StatusOr EvaluateSortCurried(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - switch (values_literal.shape().element_type()) { - case PRED: - return EvaluateSortInternal(sort, keys_literal, - 
values_literal); - case F32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case U32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case S32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case BF16: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - default: - return InvalidArgument("Unsupported type for Sort"); - } -} - -StatusOr EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - switch (sort->operand(0)->shape().element_type()) { - case F32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case U32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case S32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case BF16: - return EvaluateSortCurried(sort, keys_literal, values_literal); - default: - return InvalidArgument("Unsupported type for Sort"); - } -} -} // namespace - -Status HloEvaluator::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { - return DefaultAction(sort); + if (sort->operand_count() == 1) { + evaluated_[sort] = std::move(result_literals[0]); } else { - // This is a really stupid work-around for the fact it's hard to support a - // multi-value sort directly, due to the fact we need to template the - // evaluation function on all of the value types. 
- std::vector sort_results_backing; - for (int64 i = 0; i < sort->operand_count(); ++i) { - auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), - GetEvaluatedLiteralFor(sort->operand(i))); - if (!result.ok()) { - return result.status(); - } - sort_results_backing.push_back( - std::move(result.ValueOrDie().DecomposeTuple()[1])); - } - std::vector sort_results; - absl::c_transform(sort_results_backing, std::back_inserter(sort_results), + std::vector literal_ptrs; + absl::c_transform(result_literals, std::back_inserter(literal_ptrs), [](const Literal& literal) { return &literal; }); - evaluated_[sort] = LiteralUtil::MakeTuple(sort_results); - return Status::OK(); + + Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); + + evaluated_[sort] = std::move(result_tuple); } + return Status::OK(); } Status HloEvaluator::HandleReduce(HloInstruction* reduce) { - if (!ShapeUtil::IsTuple(reduce->shape())) { + if (!reduce->shape().IsTuple()) { return DefaultAction(reduce); } else { auto first_element_type = reduce->shape().tuple_shapes(0).element_type(); @@ -1420,6 +1585,27 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) { } } +Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) { + if (!custom_call_handler_) { + // No handler is registered; this means custom-calls are not allowed. + return DefaultAction(custom_call); + } + + // Evaluate input operands so the handler has access to the operand data. + std::vector operands; + operands.reserve(custom_call->operand_count()); + for (const HloInstruction* operand : custom_call->operands()) { + operands.push_back(&GetEvaluatedLiteralFor(operand)); + } + + // Synchronously issue the handler to populate the instruction output literal. 
+ TF_ASSIGN_OR_RETURN( + auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands))); + + evaluated_[custom_call] = std::move(output); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return ShapeUtil::ValidateShape(hlo->shape()); @@ -1437,16 +1623,46 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { return Status::OK(); } -// Explicit instantiation of templatized Evaluate* methods. -// -template StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals); +namespace { +template +std::unique_ptr> MatmulArray2DImpl( + const Array2D& lhs, const Array2D& rhs, + const std::function& impl_fn) { + CHECK_EQ(lhs.width(), rhs.height()); + int m = lhs.height(); + int n = rhs.width(); + int k = lhs.width(); + auto result = absl::make_unique>(m, n); + // Because Eigen is a header-oriented library, make sure that the Eigen code + // is the same as the code used by the CPU backend (otherwise the linker will + // randomly pick *some* definition). 
+ impl_fn( + /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, + k, + /*transpose_lhs=*/0, + /*transpose_rhs=*/0); + return result; +} +} // namespace + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); +} -template StatusOr HloEvaluator::Evaluate( - const HloComputation& computation, - absl::Span arg_literals); +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); +} -template StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals); +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 45ed8131dc6b71f706fce45d65b206363dd79ac3..357975a131d0c7e63c06e96852468b43d97a37f2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -16,12 +16,16 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#include #include #include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -42,16 +46,24 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // specified. explicit HloEvaluator(int64 max_loop_iterations = -1); - // Evaluates an HLO module and an array of pointers to literals. - // Returns the evaluated result as a literal if successful. + // Evaluates an HLO module and an array of pointers to literals. Returns the + // evaluated result as a literal if successful. + // // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template + // + // (Dummy template arg is to reduce the overloading priority of one overload + // so that Evaluate(module, {}) resolves unambiguously.) + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals) { + return Evaluate(*module.entry_computation(), arg_literals); + } + template StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals); + absl::Span arg_literals) { + return Evaluate(*module.entry_computation(), arg_literals); + } // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. 
@@ -69,29 +81,24 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template + // + // (Dummy template arg is to reduce the overloading priority of one overload + // so that Evaluate(module, {}) resolves unambiguously.) + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals); + template StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals); - - // Evaluates a single HLO instruction and an array of pointers to literals. - // Return the evaluated result as literal if successful. - // Precondition: - // 1. argument literals correspond to the input instruction's parameters in - // their post-ordering. - // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template - StatusOr Evaluate(HloInstruction* instruction, - absl::Span arg_literals); - - // Evaluates a single HLO instruction with constant operands. - // Returns the evaluated result as literal if successful. - // Precondition: - // 1. all operands of the input instruction are constants. - // 2. the instruction is not a Parameter operation. + absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& l : arg_literals) { + arg_literal_ptrs.push_back(&l); + } + return Evaluate(computation, arg_literal_ptrs); + } + + // Gets the value of running a single HLO instruction. + // + // All of the operands to this instruction must be constants. 
StatusOr Evaluate(HloInstruction* instruction); // Same as Evaluate, except returning false on error and accepts an output @@ -119,6 +126,39 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); + void set_dynamic_dimension_inference( + DynamicDimensionInference* dynamic_dimension_inference) { + dynamic_dimension_inference_ = dynamic_dimension_inference; + } + + // Enable the fast path for certain operations like dot or convolution. + void set_use_fast_path(bool value) { use_fast_path_ = value; } + + // Handles evaluation of a custom-call op. + // Operand literals are provided in |operands| and implementations must + // populate |output| before returning. + using CustomCallHandler = std::function( + HloInstruction* custom_call, absl::Span operands)>; + + // Sets a handler that is called during evaluation for custom-call ops. + // If no handler is defined the default error behavior will occur. The handler + // will be provided evaluated literals for all operands and is expected to + // return an output literal of the appropriate shape. + void set_custom_call_handler( + std::function(HloInstruction* custom_call, + absl::Span operands)> + handler) { + custom_call_handler_ = std::move(handler); + } + + // Returns the result of a matrix multiply `lhs x rhs`. + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. 
@@ -146,6 +186,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override; + Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant) override; @@ -192,16 +234,51 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleImag(HloInstruction* imag) override; + Status HandleComplex(HloInstruction* complex) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + + // Unsupported HLOs, note some of them (such as BatchNorm*) are typically + // expanded in a semantic-preserving way into other HLOs by adding expansion + // HLO pass to the HLO optimization pass during compilation, which can then be + // handled by the evaluator. + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { + return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator."); + }; + Status HandleBatchNormInference( + HloInstruction* batch_norm_inference) override { + return Unimplemented( + "BatchNormInference HLO is unsupported by the evaluator."); + }; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { + return Unimplemented( + "BatchNormTraining HLO is unsupported by the evaluator."); + }; + Status HandleInfeed(HloInstruction* infeed) override { + return Unimplemented("Infeed HLO is unsupported by the evaluator."); + }; + Status HandleOutfeed(HloInstruction* outfeed) override { + return Unimplemented("Outfeed HLO is unsupported by the evaluator."); + }; + // Returns the already-evaluated literal result for the instruction. + // // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. 
+ // + // Similarly, a Parameter instruction is considered evaluated and its literal + // is looked up in arg_literals. + // // Crash with log if the given instruction has not been evaluated previously. const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { if (hlo->IsConstant()) { return hlo->literal(); } + if (hlo->opcode() == HloOpcode::kParameter) { + return *arg_literals_.at(hlo->parameter_number()); + } auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); @@ -209,14 +286,23 @@ class HloEvaluator : public DfsHloVisitorWithDefault { } // Tracks the HLO instruction and its evaluated literal result. + // + // Parameters and constants aren't stored here, see implementation of + // GetEvaluatedLiteralFor. + // // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in // post-orderring. + // // Must be cleared for each evaluation. - // Storing Literal in place require the container to have pointer stability so - // we cannot use flat_hash_map any more. + // + // Storing Literal in place requires the container to have pointer stability + // so we cannot use flat_hash_map any more. absl::node_hash_map evaluated_; + // Use fast path that uses eigen in the evaluator. + bool use_fast_path_ = false; + private: template static StatusOr ElementWiseUnaryOpImpl( @@ -245,11 +331,27 @@ class HloEvaluator : public DfsHloVisitorWithDefault { std::vector arg_literals_; // Max loop iterations to execute with no maximum if negative. - int64 max_loop_iterations_; + int64 max_loop_iterations_ = 0; + + // Module-level seed handle. + uint64 seed_ = 0; + // RNG engine. + std::minstd_rand0 engine_; + + // DynamicDimensionInference is used to evaluate GetDimensionSize, which + // returns the dynamic dimension size of its operand. 
+ DynamicDimensionInference* dynamic_dimension_inference_ = nullptr; + + // Optional handler for custom_call ops. + std::function(HloInstruction* custom_call, + absl::Span operands)> + custom_call_handler_; TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; +std::unique_ptr> MatmulArray2D(const Array2D& lhs, + const Array2D& rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 4eaaab20ea0add17d9b49b1b2b97991af0438dcc..383921fde22242b6ede95a6554f2348ab6fd4277 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -51,20 +51,18 @@ namespace { static std::array use_bf16_params{true, false}; -class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloTestBase { - protected: - HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { - evaluator_ = absl::make_unique(); - } +// Test fixture for the HloEvaluator. +// +// In bf16 mode, all f32 shapes are converted to bf16 before running. +class HloEvaluatorTest : public HloTestBase { + public: + HloEvaluatorTest() : use_bfloat16_(false) {} Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { - // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. - auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(m_.get()).ValueOrDie(); + HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*m_->entry_computation(), arg_literals) + return evaluator_.Evaluate(*m_->entry_computation(), arg_literals) .ConsumeValueOrDie(); } @@ -74,16 +72,12 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { if (use_bfloat16_) { - // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. 
- auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(module).ValueOrDie(); + HloElementTypeConverter(F32, BF16).Run(module).ValueOrDie(); } - return evaluator_->Evaluate(*module->entry_computation(), arg_literals) + return evaluator_.Evaluate(*module->entry_computation(), arg_literals) .ConsumeValueOrDie(); } - std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, float aabs = 0) { HloComputation::Builder b(TestName()); @@ -117,16 +111,45 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } - bool use_bfloat16_; + void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0, + Literal src1, Literal src2) { + HloComputation::Builder b(TestName()); + auto operand0 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src0))); + auto operand1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src1))); + auto operand2 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src2))); + b.AddInstruction(HloInstruction::CreateTernary( + expected.shape(), opcode, operand0, operand1, operand2)); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); + } + + protected: + explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {} + HloEvaluator evaluator_; + + const bool use_bfloat16_; std::unique_ptr m_ = CreateNewVerifiedModule(); }; -#define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ - TEST_P(test_case_name, test_name) +// Lets you write TEST_Ps that run twice, once with and once without bf16. 
+class HloEvaluatorBf16Test : public ::testing::WithParamInterface, + public HloEvaluatorTest { + protected: + HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {} +}; + +INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test, + ::testing::ValuesIn(use_bf16_params)); // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. -TEST_P(HloEvaluatorTest, DoesClamp) { +TEST_P(HloEvaluatorBf16Test, DoesClamp) { auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); @@ -147,7 +170,34 @@ TEST_P(HloEvaluatorTest, DoesClamp) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { +// Verifies that clamping of int64 does not cause loss of precision +TEST_P(HloEvaluatorBf16Test, DoesClampInt64) { + auto ones = [](int bits) { return (int64{1} << bits) - 1; }; + + auto low = + LiteralUtil::CreateR2({{0, ones(54)}, {ones(54), ones(58)}}); + auto value = LiteralUtil::CreateR2({{0, ones(56)}, {0, ones(58)}}); + auto high = LiteralUtil::CreateR2( + {{ones(54), ones(55)}, {ones(56), ones(58)}}); + + Shape shape = low.shape(); + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); + b.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = + LiteralUtil::CreateR2({{0, ones(55)}, {ones(54), ones(58)}}); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { auto low = LiteralUtil::CreateR0(0.f); 
auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); @@ -170,7 +220,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. -TEST_P(HloEvaluatorTest, DoesSelect) { +TEST_P(HloEvaluatorBf16Test, DoesSelect) { auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -195,7 +245,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. -TEST_P(HloEvaluatorTest, DoesAdd) { +TEST_F(HloEvaluatorTest, DoesAdd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); @@ -204,7 +254,7 @@ TEST_P(HloEvaluatorTest, DoesAdd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. -TEST_P(HloEvaluatorTest, DoesAnd) { +TEST_P(HloEvaluatorBf16Test, DoesAnd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {4, 4}}); @@ -213,7 +263,7 @@ TEST_P(HloEvaluatorTest, DoesAnd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_P(HloEvaluatorTest, DoesOr) { +TEST_F(HloEvaluatorTest, DoesOr) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-100, 4}}); @@ -222,7 +272,7 @@ TEST_P(HloEvaluatorTest, DoesOr) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. 
-TEST_P(HloEvaluatorTest, DoesXor) { +TEST_F(HloEvaluatorTest, DoesXor) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-104, 0}}); @@ -231,7 +281,7 @@ TEST_P(HloEvaluatorTest, DoesXor) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. -TEST_P(HloEvaluatorTest, DoesMultiply) { +TEST_F(HloEvaluatorTest, DoesMultiply) { auto lhs = LiteralUtil::CreateR2({{-1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); @@ -242,14 +292,28 @@ TEST_P(HloEvaluatorTest, DoesMultiply) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_P(HloEvaluatorTest, DoesDivideInt64) { +TEST_F(HloEvaluatorTest, DoesDivideInt64) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } -TEST_P(HloEvaluatorTest, DoesDivideDouble) { + +TEST_F(HloEvaluatorTest, DoesClampS64) { + auto low = LiteralUtil::CreateR1( + {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL}); + auto value = LiteralUtil::CreateR1( + {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL}); + auto high = LiteralUtil::CreateR1( + {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL}); + auto expected = LiteralUtil::CreateR1( + {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL}); + TestTernaryOp(HloOpcode::kClamp, std::move(expected), std::move(low), + std::move(value), std::move(high)); +} + +TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) { auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = 
@@ -260,41 +324,41 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. -TEST_P(HloEvaluatorTest, DoesAbsR2) { +TEST_F(HloEvaluatorTest, DoesAbsR2) { auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesAbsR0) { +TEST_P(HloEvaluatorBf16Test, DoesAbsR0) { auto operand = LiteralUtil::CreateR0(-1.0f); auto expected = LiteralUtil::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { +TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) { auto operand = LiteralUtil::CreateR1({}); auto expected = LiteralUtil::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesNegateR2) { +TEST_F(HloEvaluatorTest, DoesNegateR2) { auto operand = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); auto expected = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesCosR2) { +TEST_P(HloEvaluatorBf16Test, DoesCosR2) { auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = LiteralUtil::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } -TEST_P(HloEvaluatorTest, DoesSinR2) { +TEST_P(HloEvaluatorBf16Test, DoesSinR2) { auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), use_bfloat16_ ? 
0.031250 : 9.5367431640625E-7); } -TEST_P(HloEvaluatorTest, DoesNotR2) { +TEST_F(HloEvaluatorTest, DoesNotR2) { auto operand = LiteralUtil::CreateR2({{0, std::numeric_limits::min()}, {-1, std::numeric_limits::max()}}); @@ -303,9 +367,22 @@ TEST_P(HloEvaluatorTest, DoesNotR2) { {0, std::numeric_limits::min()}}); TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } + +TEST_F(HloEvaluatorTest, DoesRealC128) { + auto x = LiteralUtil::CreateR1({{1, 0}, {-100, 4}}); + auto expected_real = LiteralUtil::CreateR1({1, -100}); + TestUnaryOp(HloOpcode::kReal, std::move(expected_real), std::move(x)); +} + +TEST_F(HloEvaluatorTest, DoesImagC128) { + auto x = LiteralUtil::CreateR1({{1, 0}, {-100, 4}}); + auto expected_imag = LiteralUtil::CreateR1({0, 4}); + TestUnaryOp(HloOpcode::kImag, std::move(expected_imag), std::move(x)); +} + // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. -TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { +TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); @@ -335,7 +412,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { } // Verifies Reshape operation is correctly evaluated. -TEST_P(HloEvaluatorTest, DoesReshape) { +TEST_F(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, @@ -361,7 +438,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { } // Verifies Broadcast operation is correctly evaluated. 
-TEST_P(HloEvaluatorTest, DoesBroadcast) { +TEST_F(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto output_literal = LiteralUtil::CreateR3( @@ -377,7 +454,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } -TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { +TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR0(111); auto output_literal = LiteralUtil::CreateR2( @@ -396,7 +473,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } -TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { +TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( @@ -418,7 +495,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { +TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( @@ -439,7 +516,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { +TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); @@ -458,7 +535,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } -TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { +TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2WithLayout( @@ -491,7 
+568,7 @@ PaddingConfig CreatePaddingConfig( return padding_config; } -TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { +TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto operand = LiteralUtil::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = @@ -516,7 +593,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -551,7 +628,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, NegativePadding2D) { +TEST_P(HloEvaluatorBf16Test, NegativePadding2D) { HloComputation::Builder b(TestName()); // input_array: @@ -593,7 +670,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } -TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { +TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) { HloComputation::Builder b(TestName()); // f32[4,3] { @@ -632,7 +709,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank2AndRank1) { +TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) { HloComputation::Builder b(TestName()); // lhs: @@ -678,7 +755,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank1AndRank2) { +TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -716,7 +793,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank2AndRank2) { 
+TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -766,7 +843,51 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SimpleConv1D) { +TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) { + HloComputation::Builder b(TestName()); + + auto lhs_array = absl::make_unique>(2, 2, 3, 1); + lhs_array->FillIota(1.0f); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(*lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + auto rhs_array = absl::make_unique>(2, 2, 3, 1); + rhs_array->FillIota(2.0f); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(*rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 1, 1}); + DotDimensionNumbers dot_dnums; + + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(2); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + float expected_1 = 0; + for (float i = 1.0f; i < 7.0f; ++i) { + expected_1 += i * i + i; + } + float expected_2 = 0; + for (float i = 7.0f; i < 13.0f; ++i) { + expected_2 += i * i + i; + } + auto expected_array = Array3D({{{expected_1}}, {{expected_2}}}); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_P(HloEvaluatorBf16Test, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; @@ -804,7 +925,7 @@ TEST_P(HloEvaluatorTest, 
SimpleConv1D) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -815,7 +936,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { +TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -859,7 +980,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -878,7 +999,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { +TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) { HloComputation::Builder b(TestName()); // clang-format off @@ -943,7 +1064,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -959,7 +1080,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } 
-TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { +TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) { HloComputation::Builder b(TestName()); // clang-format off @@ -1021,7 +1142,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1037,7 +1158,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { +TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1081,7 +1202,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1101,7 +1222,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { +TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1145,7 +1266,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, 
DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1166,7 +1287,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, +TEST_P(HloEvaluatorBf16Test, DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { HloComputation::Builder b(TestName()); @@ -1217,7 +1338,7 @@ TEST_P(HloEvaluatorTest, Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1239,7 +1360,7 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { +TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) { HloComputation::Builder b(TestName()); std::vector input_dims = {1, 2, 2, 4}; std::vector filter_dims = {2, 2, 2, 8}; @@ -1288,7 +1409,8 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, - /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/2, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1374,7 +1496,7 @@ void BM_ReducePrecisely(int num_iters) { BENCHMARK(BM_ReducePrecisely); -TEST_P(HloEvaluatorTest, ReduceAdd) { +TEST_P(HloEvaluatorBf16Test, ReduceAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1416,7 +1538,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, 
result)); } -TEST_P(HloEvaluatorTest, ReduceWindowMax) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) { HloComputation::Builder b(TestName()); // arg: @@ -1467,7 +1589,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) { HloComputation::Builder b(TestName()); // arg: @@ -1519,7 +1641,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowAdd) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1576,7 +1698,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) { HloComputation::Builder b(TestName()); // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. 
@@ -1639,7 +1761,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } -TEST_P(HloEvaluatorTest, StridedSlice) { +TEST_P(HloEvaluatorBf16Test, StridedSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1673,7 +1795,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DynamicSlice) { +TEST_P(HloEvaluatorBf16Test, DynamicSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1689,12 +1811,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); + auto zero = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, - start_indices, {2, 3})); + b.AddInstruction( + HloInstruction::CreateDynamicSlice(shape, operand, {zero, one}, {2, 3})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1709,7 +1833,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // Verifies that the HloEvaluator's implementation goes along with existing // backends' behavior, although this is not required by the spec. 
-TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { +TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1725,12 +1849,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2, 1}))); + auto two = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, - start_indices, {2, 3})); + b.AddInstruction( + HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1743,7 +1869,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { +TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) { HloComputation::Builder b(TestName()); // arg: @@ -1759,15 +1885,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); + auto zero = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto update = b.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - shape, operand, update, start_indices)); + shape, operand, update, {zero, one})); 
m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1780,7 +1908,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SetAndGetTuples) { +TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1816,7 +1944,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { +TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1855,7 +1983,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Reverse) { +TEST_P(HloEvaluatorBf16Test, Reverse) { HloComputation::Builder b(TestName()); // Input shape is float[4x3x2x1]. @@ -1908,7 +2036,7 @@ TEST_P(HloEvaluatorTest, Reverse) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { +TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1932,7 +2060,7 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Check that EvaluateWithSubstitutions works if one of the operands to the op // we're evaluating is a constant. 
-TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { +TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1955,7 +2083,7 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { const char* hlo_text = R"( HloModule TensorFlowGatherV1 @@ -1979,7 +2107,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { const char* hlo_text = R"( HloModule TensorFlowGatherV2 @@ -2003,7 +2131,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { const char* hlo_text = R"( HloModule TensorFlowGatherMultipleBatchDims @@ -2028,7 +2156,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { const char* hlo_text = R"( HloModule TensorFlowGatherNd @@ -2054,7 +2182,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) { const char* hlo_text = R"( HloModule TensorFlowGatherNd @@ -2081,7 +2209,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { +TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) { const char* hlo_text = R"( HloModule DynamicSlice @@ -2104,7 +2232,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, 
EvaluateGather_BatchDynamicSlice) { +TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { const char* hlo_text = R"( HloModule BatchDynamicSlice @@ -2128,7 +2256,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { +TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { const char* hlo_text = R"( HloModule TensorFlowGatherV1 @@ -2150,7 +2278,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { +TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { const string hlo_text = R"( HloModule GatherXd @@ -2175,7 +2303,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV1 @@ -2206,7 +2334,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV2 @@ -2238,7 +2366,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2270,7 +2398,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2302,7 +2430,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { +TEST_P(HloEvaluatorBf16Test, 
EvaluateScatter_TensorFlowScatter_F32) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2336,7 +2464,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2368,7 +2496,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { const char* hlo_text = R"( HloModule TensorFlowScatterMultipleBatchDims @@ -2401,7 +2529,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { const char* hlo_text = R"( HloModule TensorFlowScatterNd @@ -2437,7 +2565,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { const char* hlo_text = R"( HloModule TensorFlowScatterNdNonDefaultIndexVectorDim @@ -2474,7 +2602,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { +TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { const char* hlo_text = R"( HloModule DynamicUpdateSlice @@ -2506,7 +2634,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { +TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { const char* hlo_text = R"( HloModule BatchDynamicUpdateSlice @@ -2538,7 +2666,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } 
-TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { +TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { const char* hlo_text = R"( HloModule TensorFlowScatter_ZeroDimBounds @@ -2567,7 +2695,7 @@ ENTRY main { operand, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { +TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { const string hlo_text = R"( HloModule Scatter_NoUpdateWindowDims @@ -2600,7 +2728,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter_NegativeIndices @@ -2635,7 +2763,7 @@ ENTRY main { {&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) { const string hlo_text = R"( HloModule BatchDynamicUpdateSlice @@ -2671,7 +2799,7 @@ ENTRY main { {&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { +TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { const char* hlo_text = R"( HloModule TensorFlowScatterNd_OobUpdateWindow @@ -2710,7 +2838,7 @@ ENTRY main { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. 
-TEST_P(HloEvaluatorTest, DoesCompareBF16) { +TEST_F(HloEvaluatorTest, DoesCompareBF16) { // lhs >= rhs auto lhs = LiteralUtil::CreateR2( {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, @@ -2724,7 +2852,7 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) { std::move(rhs)); } -TEST_P(HloEvaluatorTest, Bf16Reduction) { +TEST_P(HloEvaluatorBf16Test, Bf16Reduction) { const string hlo_text = R"( HloModule Bf16Reduction @@ -2748,7 +2876,7 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); } -TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { +TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) { // Regression test for b/114735354. const string hlo_text = R"( HloModule SliceWithDifferentLayout @@ -2767,7 +2895,7 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } -TEST_P(HloEvaluatorTest, Bitcast) { +TEST_P(HloEvaluatorBf16Test, Bitcast) { // Regression test for b/114735354. constexpr absl::string_view hlo_text_base = R"( HloModule Bitcast @@ -2794,8 +2922,295 @@ ENTRY main { } } -INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, - ::testing::ValuesIn(use_bf16_params)); +// Check that s32 under/overflow doesn't trigger a ubsan failure. +TEST_F(HloEvaluatorTest, Int32Overflow) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + c1 = s32[] constant(1073741824) // 2^30 + sum = s32[] add(c1, c1) // 2^31, i.e. 
INT_MIN + + c2 = s32[] constant(-2147483648) // -2^31 + sub = s32[] subtract(c2, c1) // -2^31 - 2^30, underflows + + mul = s32[] multiply(c1, c1) + ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + std::vector actual = Evaluate({}).DecomposeTuple(); + ASSERT_EQ(actual.size(), 3); + + uint32 pow30 = uint32{1} << 30; + uint32 pow31 = uint32{1} << 31; + EXPECT_EQ(actual[0].GetFirstElement(), static_cast(pow31)); + EXPECT_EQ(actual[1].GetFirstElement(), + static_cast(-(pow31 + pow30))); + EXPECT_EQ(actual[2].GetFirstElement(), + static_cast(pow31 * pow31)); +} + +TEST_F(HloEvaluatorTest, GetDimensionSize) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + size = u32[] parameter(0) + + data = s32[4] parameter(1) + + sum = s32[4] add(data, data) + + ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + + // Set up dynamic parameter binding. + TF_CHECK_OK(m_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(m_.get())); + + evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference); + Literal size_arg = LiteralUtil::CreateR0(3); + Literal data_arg = LiteralUtil::CreateR1({1, 2, 3, 4}); + + Literal actual = Evaluate({&size_arg, &data_arg}); + + EXPECT_EQ(actual.GetFirstElement(), static_cast(3)); +} + +// Check that we get a useful error if we pass inputs of the wrong shape. 
+TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + p0 = s32[1] parameter(0) + ROOT sum = s32[1] add(p0, p0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal input_wrong_shape = LiteralUtil::CreateR1({0, 1}); + + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_, {&input_wrong_shape}) + .status() + .error_message(), + "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " + "but arg was s32[2]."); + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_->entry_computation(), {&input_wrong_shape}) + .status() + .error_message(), + "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " + "but arg was s32[2]."); +} + +// Check that we get a useful error if we pass too many or too few inputs. +TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + p0 = s32[1] parameter(0) + ROOT sum = s32[1] add(p0, p0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal input = LiteralUtil::CreateR1({0}); + + EXPECT_EQ( + HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(), + "Expected 1 argument, but got 2."); + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_->entry_computation(), {&input, &input}) + .status() + .error_message(), + "Expected 1 argument, but got 2."); +} + +TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule FusionInputLayout + + fused_computation { + param_0 = f32[20,20]{0,1} parameter(0) + ROOT bitcast = f32[20,20]{1,0} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{0,1} parameter(0) + ROOT fusion = f32[20,20]{1,0} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual 
= Evaluate({&args[0]}); + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); +} + +TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule FusionOutputLayout + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + ROOT bitcast = f32[20,20]{0,1} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = f32[20,20]{0,1} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); +} + +TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule MOFusionOutputLayout + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + bitcast = f32[20,20]{0,1} bitcast(param_0) + ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual_tuple = Evaluate({&args[0]}); + std::vector actual_literals = actual_tuple.DecomposeTuple(); + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual_literals[0].data())); +} + +// Tests that custom_calls fail to evaluate when no handler is specified. 
+TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_NoHandler + ENTRY kernel_entry { + parameter.0 = u32[2,2]{1,0} parameter(0) + ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + EXPECT_EQ(HloEvaluator().Evaluate(*m_, {&args[0]}).status().code(), + ::tensorflow::error::UNIMPLEMENTED); +} + +// Tests when a custom_call handler returns an error. +TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_HandlerError + ENTRY kernel_entry { + parameter.0 = u32[2,2]{1,0} parameter(0) + ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + HloEvaluator evaluator; + evaluator.set_custom_call_handler( + [](HloInstruction* custom_call, absl::Span operands) { + return InternalError("Test error"); + }); + EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(), + ::tensorflow::error::INTERNAL); +} + +// Tests the custom_call handler on calls with many inputs. +// We sum the operands so that we can verify the operand and output literals +// are properly mapped for access. 
+TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_ManyInputs + ENTRY kernel_entry { + parameter.0 = u32[1]{0} parameter(0) + parameter.1 = u32[1]{0} parameter(1) + ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + HloEvaluator evaluator; + evaluator.set_custom_call_handler( + [](HloInstruction* custom_call, absl::Span operands) { + EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode()); + EXPECT_EQ("_my_custom_call", custom_call->custom_call_target()); + EXPECT_EQ(2, custom_call->operand_count()); + EXPECT_EQ(2, operands.size()); + auto output = Literal::CreateFromShape(custom_call->shape()); + auto operand0_data = operands[0]->data(); + auto operand1_data = operands[1]->data(); + auto output_data = output.data(); + output_data[0] = operand0_data[0] + operand1_data[0]; + return output; + }); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + evaluator.Evaluate(*m_->entry_computation(), {&args[0], &args[1]})); + auto arg0_data = args[0].data(); + auto arg1_data = args[1].data(); + std::vector expected_data = {arg0_data[0] + arg1_data[0]}; + EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data())); +} + +TEST_F(HloEvaluatorTest, IsFiniteF16) { + constexpr absl::string_view hlo_text = R"( + HloModule test + + ENTRY IsFiniteTest { + c = f16[6] constant({nan, 7, nan, -1, inf, -inf}) + ROOT is-finite = pred[6] is-finite(c) + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_THAT(actual_literal.data(), + ::testing::ElementsAre(false, true, false, true, false, false)); +} + +TEST_F(HloEvaluatorTest, IsFiniteBf16) { + constexpr 
absl::string_view hlo_text = R"( + HloModule test + + ENTRY IsFiniteTest { + c = bf16[6] constant({nan, 7, nan, -1, inf, -inf}) + ROOT is-finite = pred[6] is-finite(c) + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_THAT(actual_literal.data(), + ::testing::ElementsAre(false, true, false, true, false, false)); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b87fc3e34012e75ee07bff6c1e113dce404f83cb..2d8a578985e8f603d4056bee8619725095ebc7bb 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -17,12 +17,15 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #include +#include #include "absl/algorithm/container.h" #include "absl/base/casts.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/meta/type_traits.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -38,48 +41,27 @@ namespace xla { // Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is // a "private" header that's not exposed outside of hlo_evaluator.cc. template -using is_complex_t = std::is_same; -template -using is_complex64_t = std::is_same; - -// It's UB to use std::sort with std::less, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. -NaN and NaN -// should appear at the beginning and end of the ordering, and -0.0 should -// appear before 0.0. 
-template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - return a < b; -} +using is_complex_t = + absl::disjunction, std::is_same>; -template ::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - bool lhs_is_negative = std::signbit(a); - bool rhs_is_negative = std::signbit(b); - // If the signs are different, we can just compare the signs. - if (lhs_is_negative != rhs_is_negative) { - return lhs_is_negative && !rhs_is_negative; - } - bool lhs_nan = std::isnan(a); - bool rhs_nan = std::isnan(b); - // Exactly one number is nan? - if (lhs_nan != rhs_nan) { - if (lhs_nan) { - return lhs_is_negative; - } - return !rhs_is_negative; - } - return a < b; +// ToArithmeticSafeType(T t): +// - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed +// integer, and +// - otherwise returns `t` unchanged. +// +// It's UB in C++ to under/overflow a signed integer, so we wrap all arithmetic +// in this type to force 2's complement behavior. +template ::value && + std::is_signed::value>::type* = nullptr> +typename std::make_unsigned::type ToArithmeticSafeType(T t) { + return static_cast::type>(t); } - -template ::value || - std::is_same::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - return SafeLess(static_cast(a), static_cast(b)); +template ::value || + !std::is_signed::value>::type* = nullptr> +T ToArithmeticSafeType(T t) { + return std::move(t); } // Templated DfsHloVisitor for use by HloEvaluator. 
@@ -105,6 +87,12 @@ bool SafeLess(const NativeT& a, const NativeT& b) { template class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { private: + Status UnsupportedTypeError(HloInstruction* instruction) { + return InvalidArgument( + "Unsupported type for %s: %s", HloOpcodeString(instruction->opcode()), + PrimitiveType_Name(instruction->shape().element_type())); + } + // Get the value in the given literal static_cast as a double. template < typename NativeT, @@ -185,7 +173,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(abs->operand(0)); @@ -204,6 +192,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // specifying the ElementwiseT explicitly as C64 is needed below. if (abs->operand(0)->shape().element_type() == C64) { return HandleAbs(abs); + } else if (abs->operand(0)->shape().element_type() == C128) { + return HandleAbs(abs); } return HandleAbs(abs); } @@ -224,7 +214,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); + return UnsupportedTypeError(round); } Status HandleRound(HloInstruction* round) override { @@ -246,7 +236,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); + return UnsupportedTypeError(ceil); } Status HandleCeil(HloInstruction* ceil) override { @@ -297,8 +287,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename 
std::enable_if::value>::type* = nullptr> - Status HandleExpm1(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Expm1"); + Status HandleExpm1(HloInstruction* expm1) { + return UnsupportedTypeError(expm1); } Status HandleExpm1(HloInstruction* floor) override { @@ -321,7 +311,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); + return UnsupportedTypeError(floor); } Status HandleFloor(HloInstruction* floor) override { @@ -339,10 +329,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleLog1p(HloInstruction* expm1) { + Status HandleLog1p(HloInstruction* log1p) { TF_ASSIGN_OR_RETURN( - parent_->evaluated_[expm1], - ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + parent_->evaluated_[log1p], + ElementWiseUnaryOp(log1p, [](ElementwiseT elem_operand) { return std::log1p(elem_operand); })); return Status::OK(); @@ -351,12 +341,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleLog1p(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Log1p"); + Status HandleLog1p(HloInstruction* log1p) { + return UnsupportedTypeError(log1p); } - Status HandleLog1p(HloInstruction* floor) override { - return HandleLog1p(floor); + Status HandleLog1p(HloInstruction* log1p) override { + return HandleLog1p(log1p); } template ::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); + return UnsupportedTypeError(not_); } Status HandleNot(HloInstruction* not_) override { @@ -433,9 +423,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return 
HandleNegate(negate); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { @@ -445,6 +435,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value || + std::is_same::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return std::isnan(elem_operand) + ? elem_operand + : std::copysign( + elem_operand != ElementwiseT(0), + elem_operand); + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -476,7 +483,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); + return UnsupportedTypeError(atan2); } Status HandleAtan2(HloInstruction* atan2) override { @@ -491,47 +498,25 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - std::is_floating_point::value || - is_complex_t::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { + Status HandleMultiply(HloInstruction* multiply) override { 
TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem * rhs_elem; - })); + ElementWiseBinaryOp( + multiply, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return ElementwiseT(ToArithmeticSafeType(lhs_elem) * + ToArithmeticSafeType(rhs_elem)); + })); return Status::OK(); } - Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply(multiply); - } - Status HandleSubtract(HloInstruction* subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem - rhs_elem; - })); + ElementWiseBinaryOp( + subtract, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return ElementwiseT(ToArithmeticSafeType(lhs_elem) - + ToArithmeticSafeType(rhs_elem)); + })); return Status::OK(); } @@ -539,7 +524,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem + rhs_elem; + return ElementwiseT(ToArithmeticSafeType(lhs_elem) + + ToArithmeticSafeType(rhs_elem)); })); return Status::OK(); } @@ -624,7 +610,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); + return UnsupportedTypeError(maximum); } Status HandleMaximum(HloInstruction* maximum) override { @@ -659,7 +645,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); + return UnsupportedTypeError(minimum); } Status HandleMinimum(HloInstruction* minimum) override { 
@@ -667,14 +653,34 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::pow(lhs_el, rhs_el); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[power], + ElementWiseBinaryOp( + power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0) + ? static_cast(1) + : std::pow(lhs_el, rhs_el); + })); + return Status::OK(); + } + + Status HandleSqrt(HloInstruction* sqrt) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sqrt], + ElementWiseUnaryOp(sqrt, [](ElementwiseT elem_operand) { + return std::sqrt(elem_operand); })); return Status::OK(); } + Status HandleRsqrt(HloInstruction* rsqrt) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[rsqrt], + ElementWiseUnaryOp(rsqrt, [](ElementwiseT elem_operand) { + return static_cast(1) / std::sqrt(elem_operand); + })); + return Status::OK(); + } + template ::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { @@ -724,7 +730,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); + return UnsupportedTypeError(remainder); } Status HandleRemainder(HloInstruction* remainder) override { @@ -746,14 +752,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } 
Status HandleAnd(HloInstruction* and_) override { @@ -775,7 +781,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); + return UnsupportedTypeError(or_); } template < @@ -804,14 +810,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } Status HandleXor(HloInstruction* xor_) override { @@ -836,8 +842,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); + Status HandleShiftLeft(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftLeft(HloInstruction* shl) override { @@ -866,8 +872,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + Status HandleShiftRightArithmetic(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightArithmetic(HloInstruction* shra) override { @@ -897,21 +903,45 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); + Status 
HandleShiftRightLogical(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightLogical(HloInstruction* shrl) override { return HandleShiftRightLogical(shrl); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + // Special case for integral type due to MSVC's std::isnan being unable to + // handle integral type. + template ::value && + std::is_integral::value>::type* = + nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return static_cast( + std::min(high, std::max(value, low))); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template ::value && + !std::is_integral::value>::type* = + nullptr> Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmin(high, std::fmax(value, low)); + if (std::isnan(low) || std::isnan(high)) { + return static_cast(NAN); + } + return static_cast( + std::min(high, std::max(value, low))); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -923,8 +953,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); + Status HandleClamp(HloInstruction* clamp) { + return UnsupportedTypeError(clamp); } Status HandleClamp(HloInstruction* clamp) override { @@ -933,7 +963,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override { CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(ShapeUtil::IsArray(select->shape())); + CHECK(select->shape().IsArray()); std::function select_op = [](bool pred, ReturnT on_true, 
ReturnT on_false) { if (pred) { @@ -986,8 +1016,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); - CHECK(ShapeUtil::IsArray(lhs_shape)); - CHECK(ShapeUtil::IsArray(rhs_shape)); + CHECK(lhs_shape.IsArray()); + CHECK(rhs_shape.IsArray()); CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); @@ -998,16 +1028,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + const auto lhs_rank = lhs_shape.rank(); + const auto rhs_rank = rhs_shape.rank(); CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums)); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, conv->feature_group_count(), + conv->batch_group_count(), window, dnums)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1030,12 +1060,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data(); auto rhs_literal_data = rhs_literal.data(); - int64 feature_group_count = conv->feature_group_count(); + const int64 feature_group_count = conv->feature_group_count(); + const int64 batch_group_count = conv->batch_group_count(); auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data, - feature_group_count](const absl::Span 
out_index) { + rhs_literal_data, feature_group_count, + batch_group_count](const absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1048,6 +1079,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 input_z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + + const int64 input_batch_size = + ShapeUtil::GetDimension(lhs_shape, input_batch_dim); + + const int64 batch_group_size = input_batch_size / batch_group_count; + // The size of an input feature group. const int64 input_feature_group_size = input_z_size / feature_group_count; @@ -1063,11 +1100,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 feature_group_index = out_index[output_z_dim] / output_feature_group_size; + const int64 batch_group_index = out_index[output_z_dim]; + ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), 0); // Convolve input feature with kernel. + // The mechanism indexes into the correct LHS (input) and RHS (kernel) + // locations and accumulates multiplications for a given output index. do { // Find corresponding spatial dimension index for input (lhs). int64 lhs_linear_spatial_index = 0; @@ -1120,11 +1161,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { feature_group_index * input_feature_group_size + rhs_iz; int64 lhs_linear_index = lhs_linear_spatial_index; + lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; + + // We are scraping only the diagonal elements in the resultant + // convolution output when batch_group_count is greater than 1, + // where 1 is the default. No scraping is done in that case. 
+ // This approach works out automatically for 'groups' in batches + // with group_size > 1, because we already descend down the batch + // dimension for the 'output_batch_dim' above. + lhs_linear_index += + ((batch_group_index * batch_group_size) % input_batch_size) * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; int64 rhs_linear_index = rhs_linear_spatial_index; + rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; @@ -1148,23 +1202,31 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleDot(HloInstruction* dot) override { - auto lhs = dot->operand(0); - auto rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); + if (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() == 1 && + parent_->use_fast_path_) { + return HandleDot(dot); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + const HloInstruction* lhs = dot->operand(0); + const HloInstruction* rhs = dot->operand(1); + CHECK(dot->shape().IsArray()); + CHECK(lhs->shape().IsArray()); + CHECK(rhs->shape().IsArray()); const auto& dnums = dot->dot_dimension_numbers(); - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + const int64 lhs_rank = lhs->shape().rank(); + const int64 rhs_rank = rhs->shape().rank(); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); // There must be 1 and only 1 Contracting dimension for lhs and rhs. 
- CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); // Contracted dimension sizes must be the same. @@ -1174,8 +1236,56 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << lhs->shape().dimensions(lhs_contracting_dimension) << " rhs contracted dimension: " << rhs->shape().dimensions(rhs_contracting_dimension); - const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracting_dimension); + + // The fast path is for a simple rank 2 dot with default layout operands. + if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 && + rhs_contracting_dimension == 0 && + LayoutUtil::Equal(lhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(rhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(dot->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2())) { + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + Array2D lhs_array(lhs->shape().dimensions(0), + contracted_dimension_size); + lhs_array.SetValues(lhs_literal.data()); + Array2D rhs_array(contracted_dimension_size, + rhs->shape().dimensions(1)); + rhs_array.SetValues(rhs_literal.data()); + std::unique_ptr> result_array = + HloEvaluator::MatmulArray2D(lhs_array, rhs_array); + Literal result(dot->shape()); + result.PopulateR2FromArray2D(*result_array); + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + return HandleDotSlowPath(dot); + } + + Status 
HandleDotSlowPath(HloInstruction* dot) { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); + CHECK(dot->shape().IsArray()); + CHECK(lhs->shape().IsArray()); + CHECK(rhs->shape().IsArray()); + + const auto& dnums = dot->dot_dimension_numbers(); + + const auto lhs_rank = lhs->shape().rank(); + const auto rhs_rank = rhs->shape().rank(); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); @@ -1190,7 +1300,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // in lhs_index or rhs_index where the i'th result index should go. absl::InlinedVector, kInlineRank> result_index_locations; - result_index_locations.reserve(lhs_rank + rhs_rank - 2); + result_index_locations.reserve( + (lhs_rank - dnums.lhs_contracting_dimensions_size()) + + (rhs_rank - dnums.rhs_contracting_dimensions_size())); // The first components in the output shape are the LHS and RHS batch // dimensions: @@ -1202,18 +1314,32 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension && + if (!absl::c_linear_search(dnums.lhs_contracting_dimensions(), i) && !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) { result_index_locations.push_back({&lhs_index[i], nullptr}); } } for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && + if (!absl::c_linear_search(dnums.rhs_contracting_dimensions(), i) && !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) { result_index_locations.push_back({&rhs_index[i], nullptr}); } } + absl::InlinedVector accumulate_index_sizes; + accumulate_index_sizes.reserve(dnums.lhs_contracting_dimensions_size()); + absl::InlinedVector, kInlineRank> + 
accumulate_index_locations; + accumulate_index_locations.reserve(dnums.lhs_contracting_dimensions_size()); + for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + const int64 lhs_dnum = dnums.lhs_contracting_dimensions(i); + const int64 rhs_dnum = dnums.rhs_contracting_dimensions(i); + accumulate_index_locations.push_back( + {&lhs_index[lhs_dnum], &rhs_index[rhs_dnum]}); + const int64 dim_size = lhs->shape().dimensions(lhs_dnum); + accumulate_index_sizes.push_back(dim_size); + } + const int64 total_contraction_size = Product(accumulate_index_sizes); Literal result(dot->shape()); TF_RETURN_IF_ERROR( result.Populate([&](absl::Span result_index) { @@ -1227,13 +1353,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } // Accumulates resulting product along the contracted dimension. - for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracting_dimension] = i; - rhs_index[rhs_contracting_dimension] = i; + absl::InlinedVector accumulate_index( + accumulate_index_sizes.size(), 0); + for (int64 k = 0; k < total_contraction_size; k++) { + for (int64 i = 0; i < accumulate_index_sizes.size(); ++i) { + *(accumulate_index_locations[i].first) = accumulate_index[i]; + *(accumulate_index_locations[i].second) = accumulate_index[i]; + } result_val += static_cast(lhs_literal.Get(lhs_index)) * static_cast(rhs_literal.Get(rhs_index)); + + // If there are no contracting dimension accumulate_index_sizes is + // empty, do not try to count down from -1 to 0 since it is and + // infinite loop. 
+ if (!accumulate_index_sizes.empty()) { + for (int64 i = accumulate_index_sizes.size() - 1; i >= 0; --i) { + int64 value = ++accumulate_index[i]; + if (value != accumulate_index_sizes[i]) { + break; + } + accumulate_index[i] = 0; + } + } } return static_cast(result_val); @@ -1244,10 +1387,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePad(HloInstruction* pad) override { - CHECK(ShapeUtil::IsArray(pad->operand(0)->shape())); + CHECK(pad->operand(0)->shape().IsArray()); // Padding value must be scalar. CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + CHECK_EQ(pad->operand(0)->shape().rank(), pad->padding_config().dimensions_size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, @@ -1270,9 +1413,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); - std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), - 0); - std::vector target_index(ShapeUtil::Rank(result.shape()), 0); + std::vector input_index(evaluated_operand.shape().rank(), 0); + std::vector target_index(result.shape().rank(), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. 
@@ -1315,10 +1457,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operand = dynamic_slice->operand(0); auto start_indices = dynamic_slice->operand(1); auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), + Cast(dynamic_slice)->index_shapes(), + dynamic_slice->dynamic_slice_sizes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1327,33 +1471,39 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { primitive_util::IsIntegralType(start_indices->shape().element_type())); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); switch (start_indices->shape().element_type()) { case S32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case S64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case U32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case U64: { TF_ASSIGN_OR_RETURN( 
parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; default: LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " @@ -1373,7 +1523,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( auto inferred_return_shape, ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); + operand->shape(), update->shape(), + Cast(dynamic_update_slice) + ->index_shapes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1384,33 +1536,39 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); switch (start_indices->shape().element_type()) { case S32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case S64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case U32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, 
+ absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case U64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; default: LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " @@ -1447,7 +1605,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Literal computed_result = - embedded_evaluator.Evaluate(*computation, arg_literals) + embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. @@ -1505,6 +1663,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; } + case C128: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } default: LOG(FATAL) << "HandleMap: unhandled primitive type for " "input operand: " @@ -1515,80 +1677,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleSort(HloInstruction* sort) { - auto keys = sort->operand(0); - TF_RET_CHECK(sort->operand_count() == 1) - << "Typed visitor does not support key-value sort"; - - const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - int64 sort_dim = sort->dimensions(0); - int64 sort_dim_elements = keys->shape().dimensions(sort_dim); - int64 rank = ShapeUtil::Rank(keys->shape()); - if (rank == 0) { - // Nothing to sort. 
- parent_->evaluated_[sort] = keys_literal.Clone(); - return Status::OK(); - } - Literal result_literal(keys_literal.shape()); - std::vector zero_base(rank, 0); - std::vector increment(rank, 1); - increment[sort_dim] = sort_dim_elements; - // Iterate through each dimension except 'sort_dim'. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), - increment, [&](absl::Span indices) -> StatusOr { - // Extract a slice from the literal that corresponds to exactly the - // row in dimension 'sort_dim'. - std::vector limit_indices(indices.begin(), indices.end()); - std::for_each(limit_indices.begin(), limit_indices.end(), - [](int64& index) { ++index; }); - limit_indices[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto row_to_sort, - keys_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& row_data = row_to_sort.data(); - - std::vector result_data(row_data.begin(), row_data.end()); - std::stable_sort(result_data.begin(), result_data.end(), - [](const NativeT& a, const NativeT& b) { - return SafeLess(a, b); - }); - Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), - {sort_dim_elements})); - sorted_row.PopulateR1(absl::Span(result_data)); - std::vector slice_dimensions(rank, 1); - slice_dimensions[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, - sorted_row.Reshape(slice_dimensions)); - std::vector start_indices(rank, 0); - TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( - sorted_row_reshaped, start_indices, indices, slice_dimensions)); - return true; - })); - parent_->evaluated_[sort] = std::move(result_literal); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleSort(HloInstruction* sort) { - return InvalidArgument("Unsupported type for Sort"); - } - Status HandleSort(HloInstruction* sort) override { - return HandleSort(sort); + return 
UnsupportedTypeError(sort); } Status HandleReduce(HloInstruction* hlo) override { HloReduceInstruction* reduce = Cast(hlo); int64 num_args = reduce->inputs().size(); - bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape()); + bool has_tuple_output = reduce->shape().IsTuple(); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); @@ -1619,7 +1715,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // All args and results have the same dimensions, so pick an arbitrary one. const Shape& arg_shape = arg_literals[0]->shape(); - const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape()) + const Shape& result_shape = reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(0) : reduce->shape(); const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); @@ -1708,7 +1804,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](Literal& literal) { return &literal; }); TF_ASSIGN_OR_RETURN(Literal computed_result, - embedded_evaluator.Evaluate( + embedded_evaluator.Evaluate( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on // the same computation. 
@@ -1786,7 +1882,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - int64 rank = ShapeUtil::Rank(operand_literal.shape()); + int64 rank = operand_literal.shape().rank(); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); DimensionVector source_index(rank, 0); @@ -1824,8 +1920,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val_literal.Set({}, *selected_val); Literal computed_result = embedded_evaluator - .Evaluate( - *select, {&selected_val_literal, &curr_val_literal}) + .Evaluate(*select, + {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); bool selected = !computed_result.Get({}); if (selected) { @@ -1846,9 +1942,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scattered_literal.Set({}, scattered); Literal computed_result = embedded_evaluator - .Evaluate( - *scatter, - {&source_literal_scatter, &scattered_literal}) + .Evaluate(*scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again @@ -1898,7 +1993,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { operand->shape().element_type(), window_dimension_sizes); DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + DimensionVector operand_index(operand_literal.shape().rank()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); Literal result(reduce_window->shape()); @@ -1922,8 +2017,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(result_val); Literal computed_result = embedded_evaluator - .Evaluate( - *function, {&result_val_literal, &curr_val_literal}) + 
.Evaluate(*function, + {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again @@ -2285,9 +2380,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(updates.Get(update_index)); Literal updated_result = embedded_evaluator - .Evaluate( - *scatter->to_apply(), - {&result_value_literal, &update_value_literal}) + .Evaluate(*scatter->to_apply(), + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. @@ -2329,7 +2423,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - const int64 rank = ShapeUtil::Rank(operand->shape()); + const int64 rank = operand->shape().rank(); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto func = [&](absl::Span out_index) { DimensionVector operand_index(rank); @@ -2357,7 +2451,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value)>::type* = nullptr> Status HandleClz(HloInstruction* clz) { - return InvalidArgument("Unsupported type for Clz"); + return UnsupportedTypeError(clz); } template ::value || is_complex_t::value>::type* = nullptr> Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); + return UnsupportedTypeError(sin); } Status HandleSin(HloInstruction* sin) override { @@ -2425,7 +2519,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); + return UnsupportedTypeError(cos); } Status HandleCos(HloInstruction* cos) override { @@ -2526,7 +2620,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { 
template ::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); + return InvalidArgument("Double is not supported for reduce precision"); } template < @@ -2534,46 +2628,172 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); + return UnsupportedTypeError(reduce_precision); } Status HandleReducePrecision(HloInstruction* reduce_precision) override { return HandleReducePrecision(reduce_precision); } - template ::value || - std::is_floating_point::value>::type* = nullptr> + template < + typename NativeT, + typename std::enable_if< + std::is_same::value || + std::is_same::value || + std::is_integral::value || is_complex_t::value || + std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); + const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); // Avoid using std::vector since std::vector does not convert to // absl::Span. - absl::InlinedVector data( - iota->shape().dimensions(iota->iota_dimension())); - std::iota(data.begin(), data.end(), 0); + absl::InlinedVector data(iota_size); + // We don't use std::iota for two reasons: + // + // (1) std:iota does not support bfloat16 and float16. + // + // (2) std::iota saturates for floating point types when the value is not + // representable, but the definition of HLO iota is the value as a + // 64-bit integer cast to the native type. + for (int64 i = 0; i < iota_size; ++i) { + // static_cast is required for Eigen::half (F16). 
+ data[i] = static_cast(i); + } auto result = LiteralUtil::CreateR1(data); - if (ShapeUtil::Rank(iota->shape()) > 1) { + if (iota->shape().rank() > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { - TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + TF_RET_CHECK(iota->shape().rank() == 1); parent_->evaluated_[iota] = std::move(result); } return Status::OK(); } + template < + typename NativeT, + typename std::enable_if< + !(std::is_same::value || + std::is_same::value || + std::is_integral::value || is_complex_t::value || + std::is_floating_point::value)>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + return UnsupportedTypeError(iota); + } + Status HandleIota(HloInstruction* iota) override { + return HandleIota(iota); + } + template ::value || std::is_floating_point::value)>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - return InvalidArgument("Unsupported type for iota"); + Status HandleRng(HloInstruction* random) { + return UnsupportedTypeError(random); } - Status HandleIota(HloInstruction* iota) override { - return HandleIota(iota); + template ::value)>::type* = nullptr> + Status HandleRng(HloInstruction* random) { + RandomDistribution distribution = random->random_distribution(); + const auto result_shape = random->shape(); + Literal result(result_shape); + + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + // std::uniform_real_distribution(a, b) can sometimes return a value + // equal to b. Unclear if this is a spec bug or an implementation bug + // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open + // interval, so we have to re-sample if we get `b` out. 
+ // + // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 + // [1] https://bugs.llvm.org/show_bug.cgi?id=18767 + // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524 + auto low_val = low.Get({}); + auto high_val = high.Get({}); + std::uniform_real_distribution generator(low_val, high_val); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + while (true) { + NativeT v = generator(parent_->engine_); + if (v != high_val) { + return v; + } + } + })); + break; + } + case RNG_NORMAL: { + const Literal& mean = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& stddev = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + std::normal_distribution generator(mean.Get({}), + stddev.Get({})); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return generator(parent_->engine_); + })); + break; + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); + } + parent_->evaluated_[random] = std::move(result); + return Status::OK(); + } + template ::value)>::type* = + nullptr> + Status HandleRng(HloInstruction* random) { + RandomDistribution distribution = random->random_distribution(); + const auto result_shape = random->shape(); + Literal result(result_shape); + + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + // Note std::uniform_int_distribution assumes interval is closed, i.e., + // [low, high], but we want [low, high) instead. Hence high-1 is used as + // the upper range. 
+ std::uniform_int_distribution generator( + low.Get({}), high.Get({}) - 1); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return static_cast(generator(parent_->engine_)); + })); + break; + } + case RNG_NORMAL: { + return Unimplemented( + "Normal distribution is not supported for integral types."); + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); + } + parent_->evaluated_[random] = std::move(result); + return Status::OK(); + } + Status HandleRng(HloInstruction* random) override { + return HandleRng(random); } private: @@ -2587,7 +2807,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // // This lets you calculate LI given the multidimensional indices in any order. static DimensionVector MakeDimMultipliers(const Shape& shape) { - DimensionVector v(ShapeUtil::Rank(shape)); + DimensionVector v(shape.rank()); int64 scale = 1; for (auto dim : LayoutUtil::MinorToMajor(shape)) { v[dim] = scale; @@ -2604,7 +2824,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Shape& window_shape, const Window& window, const Shape& base_shape, const absl::Span& window_count_index, const std::function&)>& f) { - const int64 rank = ShapeUtil::Rank(base_shape); + const int64 rank = base_shape.rank(); DimensionVector window_index(rank); std::fill(window_index.begin(), window_index.end(), 0); do { @@ -2635,12 +2855,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr DynamicSlice(const Literal& operand_literal, - const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); + StatusOr DynamicSlice( + const Literal& operand_literal, + absl::Span start_indices, + const Shape& result_shape) { + std::vector start; + + for (HloInstruction* index : 
start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); + } // Clamp the start indices so the slice is in-bounds w.r.t the operand. for (int64 i = 0; i < start.size(); ++i) { @@ -2666,14 +2890,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr DynamicUpdateSlice(const Literal& operand_literal, - const Literal& update_literal, - const Literal& start_indices_literal) { + StatusOr DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + absl::Span start_indices) { auto result = operand_literal.Clone(); - auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result.shape()); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); + const auto rank = result.shape().rank(); + std::vector start; + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); + } + // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. for (int64 i = 0; i < rank; ++i) { @@ -2790,6 +3017,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f48140ee4f6ca9415bef80c83664213109dbf9f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc new file mode 100644 index 0000000000000000000000000000000000000000..e54285a1577a3f3c97fba5ba6c2f969299ab599e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc @@ -0,0 +1,22 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc new file mode 100644 index 0000000000000000000000000000000000000000..cc708952d20a00429944c8388a84a0e610c2f38f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc @@ -0,0 +1,22 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 5be9dba3aa49d63c580cd6f5800f608667826b6a..df06cf8c53ec8407f8b44c9126ed4fb5409f8ef3 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -45,7 +45,7 @@ TEST_F(HloExecutionProfileTest, Basic) { auto shape_size_function = [&](const Shape& shape) { const int64 pointer_size = 8; - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return pointer_size; } return ShapeUtil::ByteSizeOf(shape, pointer_size); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc index c919dbd82d3668c477bf37074f1d56f8cb7d9506..862b2029718bbd802b69d789b66683a4edfa2367 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -25,7 +26,9 @@ namespace xla { namespace { -StatusOr ReplaceGetSize(HloInstruction* instr) { +StatusOr ReplaceGetSize( + HloInstruction* instr, + const DynamicDimensionInference* dynamic_dimension_inference) { if (instr->opcode() != HloOpcode::kGetDimensionSize) { return false; } @@ -36,10 +39,18 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { instr->operand(0)->shape(), instr->dimension())); TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); - uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); - HloInstruction* new_instr = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + HloInstruction* operand = instr->mutable_operand(0); + int64 dim = instr->dimension(); + HloInstruction* dynamic_size = + dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + } else { + uint32 size = instr->operand(0)->shape().dimensions(dim); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + } return true; } @@ -48,10 +59,13 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { bool changed = false; HloProto proto; + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module)); *proto.mutable_hlo_module() = 
module->ToProto(); for (auto* computation : module->computations()) { for (auto instruction : computation->instructions()) { - TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); + TF_ASSIGN_OR_RETURN(bool replaced, + ReplaceGetSize(instruction, &inference)); changed = changed || replaced; } } diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h index 30f44c23a835b3bcc935caaa917e040e07c4e703..9aa79fe66b665c48ec871c4188e44ba2056de3ad 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h @@ -21,7 +21,9 @@ limitations under the License. namespace xla { -// Pass to replace a kGetDimensionSize instruction with a constant instruction. +// Pass to replace a kGetDimensionSize instruction with a hlo instruction +// representing the dynamic size if the dimension is dynamic, otherwise a +// constant instruction representing the static size. class HloGetDimensionSizeRewriter : public HloModulePass { public: absl::string_view name() const override { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 302eca656be53a3cec86ddbf05a7fa3925c5185b..254f66021d70622bfd1c0b2623767ca7ff803e0d 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -24,9 +24,9 @@ limitations under the License. #include #include #include -#include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -38,7 +38,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -380,7 +379,7 @@ class HloDotDumper { // Each HloInstruction dumped gets a monotically-increasing node ID. This // must start at 1, because that's where graphviz's accounting starts. int64 next_node_id_ = 1; - std::unordered_map node_ids_; + absl::flat_hash_map node_ids_; // The "root" tag doesn't have an associated HloInstruction pointer, so we // need to store it outside the map. @@ -397,7 +396,7 @@ class HloDotDumper { // Each HloComputation that's emitted gets a monotonically-increasing ID. int64 next_cluster_id_ = 1; - std::unordered_map cluster_ids_; + absl::flat_hash_map cluster_ids_; // Edges to print from Footer(). Edges come at the end because graphviz is // unhappy if an edge from a subcomputation to a node in the outer computation @@ -407,7 +406,7 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map + absl::flat_hash_map sharding_colors_; int64 next_shard_color_ = 0; }; @@ -536,7 +535,12 @@ stylesheet=< } } - return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); + // Browsers require that we URI-encode the contents of our data URI. (It + // seems this was a relatively recent change?) In practice, this means that we + // need to escape '#'. 
+ return StrFormat( + fmt, graph_label, + absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}})); } string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } @@ -561,8 +565,8 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { } // Show the subcomputation if we're showing any of its members. - return std::any_of( - subcomp->instructions().begin(), subcomp->instructions().end(), + return absl::c_any_of( + subcomp->instructions(), [&](const HloInstruction* instr) { return filter_.Show(instr); }); } @@ -733,17 +737,16 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { return true; } const int kMinUsersToOmit = 3; - return instr->opcode() == HloOpcode::kParameter && - ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && - std::count_if(instr->users().begin(), instr->users().end(), - [&](const HloInstruction* user) { - return filter_.Show(user); - }) > kMinUsersToOmit && - std::all_of(instr->users().begin(), instr->users().end(), - [&](const HloInstruction* user) { - return !filter_.Show(user) || - user->opcode() == HloOpcode::kGetTupleElement; - }); + return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() && + !instr->IsFused() && + absl::c_count_if(instr->users(), + [&](const HloInstruction* user) { + return filter_.Show(user); + }) > kMinUsersToOmit && + absl::c_all_of(instr->users(), [&](const HloInstruction* user) { + return !filter_.Show(user) || + user->opcode() == HloOpcode::kGetTupleElement; + }); } string HloDotDumper::DumpInstruction(const HloInstruction* instr) { @@ -816,7 +819,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Print the literal value of constants with <= K elements. 
optional elem_count; - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { elem_count = 1; for (int64 dim : shape.dimensions()) { *elem_count *= dim; @@ -900,12 +903,11 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { // the same color as a parameter. Unless the merged-in parameter is a // parameter to a fusion node that is bound to a constant -- these aren't // "real" parameters from the user's perspective. - if (std::any_of(instr->operands().begin(), instr->operands().end(), - [&](const HloInstruction* operand) { - return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand) && - TryGetFusionParameterConstant(operand) == nullptr; - })) { + if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) { + return operand->opcode() == HloOpcode::kParameter && + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; + })) { return parameter_color; } @@ -951,6 +953,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kRemainder: case HloOpcode::kRng: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRsqrt: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: @@ -959,6 +962,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSin: case HloOpcode::kSlice: case HloOpcode::kSort: + case HloOpcode::kSqrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: // De-emphasize scalar-shaped elementwise ops -- they're generally @@ -1013,6 +1017,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: + case HloOpcode::kTriangularSolve: return kDarkBlue; case HloOpcode::kReducePrecision: return kRed; @@ -1030,7 +1035,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: case HloOpcode::kGetDimensionSize: 
return kGray; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: @@ -1039,6 +1044,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kRecvDone: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kReplicaId: return kBrown; case HloOpcode::kCall: case HloOpcode::kConditional: @@ -1282,11 +1288,12 @@ namespace { // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. -NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, - int64 radius) { +NodeFilter MakeNodeRadiusAroundFilter( + const HloInstruction* root, int64 radius, + const absl::flat_hash_set& boundary) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. - std::unordered_map nodes; + absl::flat_hash_map nodes; std::deque> worklist; worklist.push_back({root, 0}); while (!worklist.empty()) { @@ -1299,6 +1306,9 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, if (depth == radius) { continue; } + if (boundary.contains(instr)) { + continue; + } // Traverse into instr's operands. // @@ -1307,7 +1317,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, // are not interesting to the graph at hand. 
if (instr == root || instr->opcode() != HloOpcode::kTuple) { for (const HloInstruction* operand : instr->operands()) { - if (!nodes.count(operand)) { + if (!nodes.contains(operand)) { worklist.push_back({operand, depth + 1}); } } @@ -1335,7 +1345,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, continue; } for (const HloInstruction* user : instr->users()) { - if (!nodes.count(user)) { + if (!nodes.contains(user)) { worklist.push_back({user, depth + 1}); } } @@ -1344,7 +1354,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, auto is_displayed = [&](const HloInstruction* instr) { // Constants are displayed inline with their users; they're never omitted. // Nodes in subcomputations are always shown. - return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || + return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant || instr->parent() != root->parent(); }; @@ -1355,12 +1365,11 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, NodeFilterResult& filter_result = kv.second; const auto& operands = instr->operands(); - if (std::any_of(operands.begin(), operands.end(), is_displayed) && - !std::all_of(operands.begin(), operands.end(), is_displayed)) { + if (absl::c_any_of(operands, is_displayed) && + !absl::c_all_of(operands, is_displayed)) { // Mark nodes with some operands omitted appropriately. filter_result = kSomeOperandsOmitted; - } else if (!operands.empty() && - std::none_of(operands.begin(), operands.end(), is_displayed)) { + } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) { // Mark nodes with *all* operands omitted appropriately. filter_result = kOmitNodeOperands; } @@ -1368,8 +1377,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their // users made it into the graph. 
if (filter_result == kSomeUsersOmitted && - std::all_of(instr->users().begin(), instr->users().end(), - is_displayed)) { + absl::c_all_of(instr->users(), is_displayed)) { filter_result = kNormalNode; } } @@ -1449,9 +1457,6 @@ string SaveGraph(const string& graph, case GraphRendererInterface::DOT_GRAPH: file_extension = ".dot"; break; - case GraphRendererInterface::TF_GRAPHDEF: - file_extension = ".pbtxt"; - break; } string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, ".")); auto status = Status::OK(); @@ -1474,39 +1479,42 @@ string ExportGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const DebugOptions& debug_options) { string path = debug_options.xla_hlo_graph_path(); - if (!path.empty()) { + if (!path.empty() && !debug_options.xla_hlo_dump_as_html()) { return SaveGraph(graph, graph_kind, path); } else { auto graph_renderer = GraphRendererRegistry::Default()->GetDefaultRenderer(); CHECK(graph_renderer != nullptr) << "No registered renderer for the HLO graph. " - "Use --xla_hlo_graph_path=PATH to export to local file system"; + "Use --xla_hlo_graph_path=PATH --xla_hlo_dump_as_html=false to " + "export to local file system"; return graph_renderer->RenderGraph(graph, graph_kind, debug_options); } } } // namespace +string HloComputationToDotGraph(const HloComputation& computation, + const DotGraphOptions& options) { + DebugOptions default_debug_options; + return HloDotDumper(&computation, options.label, + options.debug_options ? 
*options.debug_options + : default_debug_options, + options.show_backend_config, options.profile, + NodeFilter()) + .Dump(); +} + string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; - string graph; - if (debug_options.xla_hlo_dump_as_graphdef()) { - HloTfGraphBuilder builder(debug_options); - TF_CHECK_OK(builder.AddComputation(computation)); - CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), - &graph)); - graph_kind = GraphRendererInterface::TF_GRAPHDEF; - } else { - graph = - HloDotDumper(&computation, label, debug_options, show_backend_config, - hlo_execution_profile, NodeFilter()) - .Dump(); - graph_kind = GraphRendererInterface::DOT_GRAPH; - } + string graph = + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) + .Dump(); + graph_kind = GraphRendererInterface::DOT_GRAPH; string graph_url = ExportGraph(graph, graph_kind, debug_options); LOG(INFO) << "computation " << computation.name() << " [" << label @@ -1514,12 +1522,13 @@ string DumpGraph(const HloComputation& computation, const string& label, return graph_url; } -string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_backend_config) { +string DumpNeighborhoodAround( + const HloInstruction& node, int radius, bool show_backend_config, + const absl::flat_hash_set& boundary) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); - NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius); + NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius, boundary); string graph = HloDotDumper(node.parent(), label, debug_options, show_backend_config, /*profile=*/nullptr, filter) @@ -1589,5 +1598,145 @@ string 
MaybeDumpHloModule(const HloModule& module, const string& label, return graph_url; } +string WrapDotInHTML(const string& dot) { + static const char html_prefix[] = R"html( + + + + + + + + + + + +
+ + + +)html"; + + return html_prefix + dot + html_suffix; +} + +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options) { + string html = WrapDotInHTML(dot); + + auto env = tensorflow::Env::Default(); + std::vector dirs; + string output_dir = debug_options.xla_hlo_graph_path(); + if (output_dir.empty()) { + env->GetLocalTempDirectories(&dirs); + } else { + dirs.push_back(output_dir); + } + // Try each directory, as they might be full, have inappropriate + // permissions or have different problems at times. + string output; + for (const string& dir : dirs) { + string filename = tensorflow::io::JoinPath(dir, "graph-"); + if (env->CreateUniqueFileName(&filename, ".html")) { + output = filename; + break; + } + } + if (output.empty()) { + LOG(FATAL) << "Failed to create unique output file name."; + } + TF_CHECK_OK(tensorflow::WriteStringToFile(env, output, html)); + return "file://" + output; +} + } // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index de1eefab776f9c3d2c73959a5cd267e938a78a32..563cea42371d370b4c9ea739418692fd74dca799 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -26,13 +26,23 @@ limitations under the License. namespace xla { namespace hlo_graph_dumper { +// Converts a HLO module to a DOT (graphviz) graph. Returns the dot graph as +// a string. +struct DotGraphOptions { + absl::string_view label; + const DebugOptions* debug_options = nullptr; + const HloExecutionProfile* profile = nullptr; + bool show_backend_config = false; +}; +string HloComputationToDotGraph(const HloComputation& computation, + const DotGraphOptions& options); + // Abstract interface for classes that render HLO graphs (e.g. DOT graph, -// tensorflow GraphDef). +// tensorflow GraphDef) to files or services. 
class GraphRendererInterface { public: enum GraphKind { DOT_GRAPH, - TF_GRAPHDEF, }; virtual ~GraphRendererInterface() = default; @@ -63,8 +73,12 @@ string DumpGraph(const HloComputation& computation, const string& label, // The number of nodes dumped is controlled by the radius parameter, which // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. -string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_backend_config = false); +// +// The optional boundary specifies a set of boundary nodes, beyond which nodes +// will be omitted even if they are within the radius. +string DumpNeighborhoodAround( + const HloInstruction& node, int radius, bool show_backend_config = false, + const absl::flat_hash_set& boundary = {}); // Dumps nodes on any of the paths from `from` to `to`. If there are more than // max_nodes on all paths, restricts to the max_nodes nodes on the shortest @@ -81,6 +95,12 @@ string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix = true); +// Renders DOT graph as inline SVG and saves it in an HTML file in a temprary +// directory or directory specified via --xla_hlo_graph_path. Returns the file +// URI pointing to the file. +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options); + // Graph renderers may be added using a registration mechanism, e.g.: // XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) // The renderer with the highest numeric priority value is used. 
diff --git a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc new file mode 100644 index 0000000000000000000000000000000000000000..84c4cf18df69816c611f4eb159ba247320ebc20e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of an DOT graph renderer that uses Javascript to render DOT to +// SVG in a browser. 
+ +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +namespace hlo_graph_dumper { +namespace { + +class GraphHtmlRenderer : public GraphRendererInterface { + public: + string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) override { + switch (graph_kind) { + case DOT_GRAPH: + return RenderDotAsHTMLFile(graph, debug_options); + default: + LOG(FATAL) << "Only DOT graphs can be rendered"; + } + } +}; + +XLA_REGISTER_GRAPH_RENDERER(GraphHtmlRenderer); + +} // namespace +} // namespace hlo_graph_dumper +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index 6e1597fd03db0a78aa560340b7b9b64fe500df0c..b01c00121b3363630b83a1e49d0027a66f3a9e1a 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -17,22 +17,34 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { + +bool HloInputOutputAliasConfig::OutputHasAlias( + const ShapeIndex& output_index) const { + return alias_.element(output_index).has_value(); +} + Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + const ShapeIndex& param_index, + AliasKind kind) { + TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias) + << kind; TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) << absl::StrCat("Tring to set up alias at ", output_index.ToString(), " which is an invalid index for shape ", ShapeUtil::HumanString(alias_.shape())); + TF_RET_CHECK(param_number >= 0) << param_number; + TF_RET_CHECK(!OutputHasAlias(output_index)) + << "Output index " << output_index << " already has an alias setup"; // Output can't be aliased with multiple parameters. TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat( "Trying to set up output alias for param %lld at %s but failed: output " "index %s is already aliased with param %lld at %s", param_number, param_index.ToString(), output_index.ToString(), - alias_.element(output_index)->first, - alias_.element(output_index)->second.ToString()); + alias_.element(output_index)->parameter_number, + alias_.element(output_index)->parameter_index.ToString()); (*alias_.mutable_element(output_index)) = - std::make_pair(param_number, param_index); + Alias(kind, param_number, param_index); VLOG(4) << "Set up alias between output index " << output_index.ToString() << " and parameter " << param_index << " at index " << param_index.ToString(); @@ -42,15 +54,24 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { HloInputOutputAliasProto result; alias_.ForEachElement( - [&](const ShapeIndex& index, - const absl::optional>& data) { + [&](const ShapeIndex& index, 
const absl::optional& data) { if (data) { HloInputOutputAliasProto::AliasEntryProto entry; + switch (data->kind) { + case AliasKind::kUserAlias: + entry.set_kind(HloInputOutputAliasProto::USER_ALIAS); + break; + case AliasKind::kSystemAlias: + entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS); + break; + default: + LOG(FATAL) << "Unknown alias kind " << data->kind; + } for (int64 i : index) { entry.add_output_shape_index(i); } - entry.set_parameter_number(data->first); - for (int64 i : data->second) { + entry.set_parameter_number(data->parameter_number); + for (int64 i : data->parameter_index) { entry.add_parameter_shape_index(i); } result.add_entries()->Swap(&entry); @@ -66,14 +87,18 @@ StatusOr HloInputOutputAliasConfig::CreateFromProto( proto.entries()) { ShapeIndex output_index(entry.output_shape_index().begin(), entry.output_shape_index().end()); - int64 param_number = entry.parameter_number(); ShapeIndex param_index(entry.parameter_shape_index().begin(), entry.parameter_shape_index().end()); + // Handle backward compatibility with existing protos, which only knew of + // system aliases. + AliasKind kind = AliasKind::kSystemAlias; + if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) { + kind = AliasKind::kUserAlias; + } TF_RETURN_IF_ERROR( - result.SetUpAlias(output_index, param_number, param_index)); + result.SetUpAlias(output_index, param_number, param_index, kind)); } - return result; } @@ -81,45 +106,44 @@ string HloInputOutputAliasConfig::ToString() const { std::vector pieces; pieces.push_back("HloInputOutputAliasConfig"); - ForEachAlias([&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) { + const char* kind = alias.kind == AliasKind::kUserAlias ? 
"USER" : "SYSTEM"; pieces.push_back(absl::StrFormat( - " OutputIndex %s is aliased with parameter %lld at %s:", - output_index.ToString(), param_number, param_index.ToString())); + " OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:", + output_index.ToString(), kind, alias.parameter_number, + alias.parameter_index.ToString())); }); - return absl::StrJoin(pieces, "\n"); } -bool HloInputOutputAliasConfig::ParameterHasAlias( +HloInputOutputAliasConfig::AliasKind +HloInputOutputAliasConfig::ParameterAliasKind( int64 param_number, const ShapeIndex& param_index) const { - bool output = false; + AliasKind kind = AliasKind::kNoAlias; alias_.ForEachElement( - [&](const xla::ShapeIndex&, - absl::optional> alias) { - if (alias && alias->first == param_number && - alias->second == param_index) { - output = true; + [&](const xla::ShapeIndex&, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index) { + kind = alias->kind; } }); - return output; + return kind; } absl::optional HloInputOutputAliasConfig::GetAliasedOutput( int64 param_number, const ShapeIndex& param_index) const { absl::optional output; alias_.ForEachElement( - [&](const xla::ShapeIndex& output_index, - absl::optional> alias) { - if (alias && alias->first == param_number && - alias->second == param_index) { + [&](const xla::ShapeIndex& output_index, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index) { output = output_index; } }); return output; } -absl::optional> +absl::optional HloInputOutputAliasConfig::GetAliasedParameter( const ShapeIndex& output_index) const { CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); @@ -128,10 +152,9 @@ HloInputOutputAliasConfig::GetAliasedParameter( void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { alias_.ForEachElement( - [&](const ShapeIndex& output_index, - absl::optional> aliased) { + [&](const 
ShapeIndex& output_index, absl::optional aliased) { if (aliased) { - fn(output_index, aliased->first, aliased->second); + fn(output_index, *aliased); } }); } @@ -139,10 +162,9 @@ void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { Status HloInputOutputAliasConfig::ForEachAliasWithStatus( AliasFnWithStatus fn) const { return alias_.ForEachElementWithStatus( - [&](const ShapeIndex& output_index, - absl::optional> aliased) { + [&](const ShapeIndex& output_index, absl::optional aliased) { if (aliased) { - TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second)); + TF_RETURN_IF_ERROR(fn(output_index, *aliased)); } return Status::OK(); }); @@ -158,20 +180,19 @@ Status HloInputOutputAliasConfig::Verify( param_has_seen.emplace_back(param->shape()); } return ForEachAliasWithStatus([&](const ShapeIndex& output_index, - int64 param_number, - const ShapeIndex& param_index) -> Status { + const Alias& alias) -> Status { const HloInstruction* root = entry->root_instruction(); - TF_RET_CHECK(0 <= param_number); - TF_RET_CHECK(entry->num_parameters() > param_number); + TF_RET_CHECK(0 <= alias.parameter_number); + TF_RET_CHECK(entry->num_parameters() > alias.parameter_number); const Shape& param_shape = - entry->parameter_instruction(param_number)->shape(); + entry->parameter_instruction(alias.parameter_number)->shape(); const Shape& output_shape = root->shape(); - TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index)); TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index)); const Shape& param_subshape = - ShapeUtil::GetSubshape(param_shape, param_index); + ShapeUtil::GetSubshape(param_shape, alias.parameter_index); const Shape& output_subshape = ShapeUtil::GetSubshape(output_shape, output_index); TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape)); @@ -182,19 +203,20 @@ Status HloInputOutputAliasConfig::Verify( "Expected aliased input %lld at index %s and 
output at index %s to " "have the same size. Input sub-shape is %s with size %lld, output " "sub-shape is %s with size %lld", - param_number, param_index.ToString(), output_index.ToString(), + alias.parameter_number, alias.parameter_index.ToString(), + output_index.ToString(), ShapeUtil::HumanStringWithLayout(param_subshape), size_func(param_subshape), ShapeUtil::HumanStringWithLayout(output_subshape), size_func(output_subshape)); } - // Check each param_number and param_index pair only show up once. No - // input can be aliased with output buffers. - TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false); - - *(param_has_seen[param_number].mutable_element(param_index)) = true; - + // Check each alias.parameter_number and alias.parameter_index pair only + // show up once. No input can be aliased with output buffers. + TF_RET_CHECK(param_has_seen[alias.parameter_number].element( + alias.parameter_index) == false); + *(param_has_seen[alias.parameter_number].mutable_element( + alias.parameter_index)) = true; return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index 439676b1546c4af7f781fb80bccffd5248309b0f..cd13c7a3ac7afe03fb99ed3114bdc6ac0f8ad6a7 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -31,21 +32,54 @@ class HloModule; // parameter index in the entry computation. class HloInputOutputAliasConfig { public: + // The kind of aliases which can be set. A kUserAlias is one setup at + // compilation time by the user, and has to be respected. 
A kSystemAlias one + // might be setup by the compiler, if it decides it is convenient to do so. + enum AliasKind { + kNoAlias, + kUserAlias, + kSystemAlias, + }; + + // Defines the alias information for a given output buffer. A given output + // buffer shape index can refer only to one parameter+index. + struct Alias { + Alias(AliasKind kind, int64 parameter_number, ShapeIndex parameter_index) + : kind(kind), + parameter_number(parameter_number), + parameter_index(std::move(parameter_index)) {} + + AliasKind kind; + int64 parameter_number; + ShapeIndex parameter_index; + }; + HloInputOutputAliasConfig() = default; - explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} + explicit HloInputOutputAliasConfig(Shape output_shape) + : alias_(output_shape) {} virtual ~HloInputOutputAliasConfig() = default; // Sets up alias config from `output_index` to `param_index` at // `param_number`. Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index); + const ShapeIndex& param_index, AliasKind kind); + + // Returns the kind of alias for the given parameter number and parameter + // index. If no alias exists, AliasKind::kNoAlias is returned. + AliasKind ParameterAliasKind(int64 param_number, + const ShapeIndex& param_index) const; // Returns true if the given parameter is aliased with one of the output // buffers. bool ParameterHasAlias(int64 param_number, - const ShapeIndex& param_index) const; + const ShapeIndex& param_index) const { + return ParameterAliasKind(param_number, param_index) != AliasKind::kNoAlias; + } + + // Checks whether the provided output index has already been aliased. + bool OutputHasAlias(const ShapeIndex& output_index) const; // (De)Serializes an HloInputOutoutAliasConfig to/from an // HloInputOutoutAliasProto. @@ -63,19 +97,17 @@ class HloInputOutputAliasConfig { // Returns the number of parameter and index of the parameter buffer that the // given output buffer index is aliased with. 
A nullopt is returned if there // is no parameter is aliased with the specific output. - absl::optional> GetAliasedParameter( + absl::optional GetAliasedParameter( const ShapeIndex& output_index) const; using AliasFn = - std::function; + std::function; // Iterates through each aliased output and input. void ForEachAlias(AliasFn fn) const; using AliasFnWithStatus = - std::function; + std::function; // Verifies that the given config is valid for the given module. // Specifically, the config's input and output should be in-bound and size of @@ -90,9 +122,10 @@ class HloInputOutputAliasConfig { private: // A ShapeTree which indicates the list of buffers that's expected to be // aliased. The key on this shape tree represents the output index. The value - // is a pair of parameter number and index into the buffer. If the value is - // nullopt, it means there is no parameter aliasing for this output. - ShapeTree>> alias_; + // is an Alias data structure which defines the input parameter coordinates. + // If the value is nullopt, it means there is no parameter aliasing for this + // output. + ShapeTree> alias_; }; std::ostream& operator<<(std::ostream& out, diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc index aeb9b0fdc8b6cca87731a2d4aae25120af6c3215..265bfdf7f989b0821a98c1f774cb408b78f348fe 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -45,11 +44,12 @@ class HloInputOutputAliasConfigTest : public HloTestBase { EXPECT_TRUE(aliased_output); EXPECT_EQ(aliased_output.value(), output_index); - absl::optional> aliased_param = + absl::optional aliased_param = config.GetAliasedParameter(output_index); EXPECT_TRUE(aliased_param); - EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index)); + EXPECT_EQ(aliased_param->parameter_number, param_number); + EXPECT_EQ(aliased_param->parameter_index, param_index); } void expect_not_aliased(const ShapeIndex& output_index, int64 param_number, @@ -60,11 +60,12 @@ class HloInputOutputAliasConfigTest : public HloTestBase { EXPECT_FALSE(aliased_output && aliased_output == output_index); - absl::optional> aliased_param = + absl::optional aliased_param = config.GetAliasedParameter(output_index); - EXPECT_FALSE(aliased_param && aliased_param->first == param_number && - aliased_param->second == param_index); + EXPECT_FALSE(aliased_param && + aliased_param->parameter_number == param_number && + aliased_param->parameter_index == param_index); } }; @@ -84,8 +85,10 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); expect_aliased(/*output_index=*/{0}, /*param_number=*/1, /*param_index=*/{}, config); @@ -114,11 +117,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, 
/*param_number=*/0, - /*param_index=*/{0})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{1})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); expect_aliased(/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, config); @@ -149,11 +156,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -176,8 +187,10 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -200,11 +213,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, 
/*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{})); + ASSERT_IS_NOT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 21b1dbc1676cccd2fe5b331a1f9d6ff5e3a73fcd..6c47bb8935a471743829ae3539c806d0465362c6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -82,86 +83,70 @@ StatusOr> HloInstruction::CreateFromProto( return computation_map.at(proto.called_computation_ids(index)); }; - TF_RET_CHECK(std::all_of( - proto.operand_ids().begin(), proto.operand_ids().end(), - [&instruction_map](int64 id) { return instruction_map.contains(id); })) + TF_RET_CHECK( + absl::c_all_of(proto.operand_ids(), + [&](int64 id) { return instruction_map.contains(id); })) << proto.name() << " instruction contains invalid operand id(s)"; - TF_RET_CHECK(std::all_of( - proto.called_computation_ids().begin(), - proto.called_computation_ids().end(), - [&computation_map](int64 id) { return computation_map.contains(id); })) + TF_RET_CHECK( + absl::c_all_of(proto.called_computation_ids(), + [&](int64 id) { return computation_map.contains(id); })) << proto.name() << " instruction references invalid 
computation id(s)"; Shape shape(proto.shape()); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + absl::optional arity = HloOpcodeArity(opcode); + if (arity) { + TF_RET_CHECK(proto.operand_ids_size() == *arity) + << proto.opcode() << " instruction should have " << *arity + << " operands but sees " << proto.operand_ids_size(); + } + switch (opcode) { // Ops migrated to subclasses. case HloOpcode::kBatchNormTraining: - TF_RET_CHECK(proto.operand_ids_size() == 3) - << "BatchNormTraining instruction should have 3 operands but sees " - << proto.operand_ids_size(); instruction = CreateBatchNormTraining(shape, operands(0), operands(1), operands(2), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormInference: - TF_RET_CHECK(proto.operand_ids_size() == 5) - << "BatchNormInference instruction should have 5 operands but sees " - << proto.operand_ids_size(); instruction = CreateBatchNormInference( shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormGrad: - TF_RET_CHECK(proto.operand_ids_size() == 5) - << "BatchNormGrad instruction should have 5 operands but sees " - << proto.operand_ids_size(); instruction = CreateBatchNormGrad(shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kFft: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Fft instruction should have 1 operand but sees " - << proto.operand_ids_size(); std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(shape, operands(0), proto.fft_type(), absl::Span(fft_length)); break; } + case HloOpcode::kTriangularSolve: { + instruction = CreateTriangularSolve(shape, operands(0), operands(1), + proto.triangular_solve_options()); + break; + } case HloOpcode::kSend: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Send instruction should have 2 operand 
but sees " - << proto.operand_ids_size(); instruction = CreateSend(operands(0), operands(1), proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kSendDone: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "SendDone instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateSendDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kRecv: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Recv instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateRecv(shape.tuple_shapes(0), operands(0), proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kRecvDone: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "RecvDone instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateRecvDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kReverse: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Reverse instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateReverse(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); @@ -201,26 +186,21 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); TF_RET_CHECK(proto.dimensions().size() == 1) << "Sort instruction should have 1 dimension"; + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Sort instruction should one called computation but sees " + << proto.called_computation_ids_size(); auto sort_operands = all_operands(); - HloInstruction* keys = sort_operands[0]; - instruction = CreateSort( - shape, proto.dimensions(0), keys, - absl::Span(sort_operands).subspan(1)); + instruction = CreateSort(shape, proto.dimensions(0), all_operands(), + computations(0), proto.is_stable()); break; } case HloOpcode::kTranspose: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Transpose instruction should have 1 operand but sees " - << proto.operand_ids_size(); 
instruction = CreateTranspose(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; case HloOpcode::kBroadcast: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Broadcast instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateBroadcast(shape, operands(0), std::vector(proto.dimensions().begin(), @@ -233,9 +213,6 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateMap(shape, all_operands(), computations(0)); break; case HloOpcode::kSlice: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Slice instruction should have 1 operand but sees " - << proto.operand_ids_size(); std::vector slice_starts, slice_limits, slice_strides; for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { @@ -259,9 +236,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kTrace: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Trace instruction should have 1 operand but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); @@ -295,37 +269,29 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kParameter: instruction = CreateParameter(proto.parameter_number(), shape, proto.name()); + if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) { + instruction->set_parameter_replicated_at_leaf_buffers( + proto.parameter_replication().replicated_at_leaf_buffers()); + } break; case HloOpcode::kGetTupleElement: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "GetTupleElement instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateGetTupleElement(shape, operands(0), proto.tuple_index()); break; case HloOpcode::kReducePrecision: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "ReducePrecision instruction should have 1 operand but sees " - << proto.operand_ids_size(); 
instruction = CreateReducePrecision( shape, operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { - TF_RET_CHECK(ShapeUtil::IsTuple(shape) && + TF_RET_CHECK(shape.IsTuple() && (ShapeUtil::TupleElementCount(shape) == 2)) << "Infeed should have a tuple shape with 2 operands, but has: " << shape; const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0); - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Infeed instruction should have 1 operand but sees " - << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Outfeed instruction should have 2 operands but sees " - << proto.operand_ids_size(); Shape outfeed_shape(proto.outfeed_shape()); TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape)); @@ -333,20 +299,20 @@ StatusOr> HloInstruction::CreateFromProto( proto.outfeed_config()); break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "CrossReplicaSum should have 1 called computation but sees " + << "AllReduce should have 1 called computation but sees " << proto.called_computation_ids_size(); absl::optional all_reduce_id; if (proto.all_reduce_id() > 0) { all_reduce_id = proto.all_reduce_id(); } - instruction = CreateCrossReplicaSum( + instruction = CreateAllReduce( shape, all_operands(), computations(0), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), - /*barrier=*/proto.cross_replica_sum_barrier(), + /*barrier=*/proto.all_reduce_barrier(), /*all_reduce_id=*/all_reduce_id); break; } @@ -359,9 +325,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kCollectivePermute: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "CollectivePermute instruction should have 1 operand but sees " - << 
proto.operand_ids_size(); std::vector> source_target_pairs( proto.source_target_pairs_size()); for (int i = 0; i < source_target_pairs.size(); i++) { @@ -372,10 +335,11 @@ StatusOr> HloInstruction::CreateFromProto( CreateCollectivePermute(shape, operands(0), source_target_pairs); break; } + case HloOpcode::kReplicaId: { + instruction = CreateReplicaId(); + break; + } case HloOpcode::kConvolution: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Convolution instruction should have 2 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); PrecisionConfig precision_config = proto.precision_config(); @@ -383,14 +347,12 @@ StatusOr> HloInstruction::CreateFromProto( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( shape, operands(0), operands(1), - std::max(proto.feature_group_count(), 1), proto.window(), + std::max(proto.feature_group_count(), 1), + std::max(proto.batch_group_count(), 1), proto.window(), proto.convolution_dimension_numbers(), precision_config); break; } case HloOpcode::kReduceWindow: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "ReduceWindow instruction should have 2 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "ReduceWindow should have 1 called computation but sees " << proto.called_computation_ids_size(); @@ -398,9 +360,6 @@ StatusOr> HloInstruction::CreateFromProto( proto.window(), computations(0)); break; case HloOpcode::kSelectAndScatter: - TF_RET_CHECK(proto.operand_ids_size() == 3) - << "SelectAndScatter instruction should have 3 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 2) << "SelectAndScatter should have 2 called computations but sees " << proto.called_computation_ids_size(); @@ -438,29 +397,56 @@ StatusOr> HloInstruction::CreateFromProto( static_cast(instruction.get()) 
->set_feature_group_count( std::max(static_cast(proto.feature_group_count()), 1LL)); + static_cast(instruction.get()) + ->set_batch_group_count( + std::max(static_cast(proto.batch_group_count()), 1LL)); break; case HloOpcode::kPad: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Pad instruction should have 2 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_padding_config()); instruction = CreatePad(shape, operands(0), operands(1), proto.padding_config()); break; case HloOpcode::kDynamicSlice: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "DynamicSlice instruction should have 2 operands but sees " - << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "DynamicSlice instruction should have at least 1 operands but " + "sees " + << proto.operand_ids_size(); + // TODO(b/118437727): Old form, make the check unconditional. + if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) { + auto expected_operands = 1 + operands(0)->shape().rank(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << "DynamicSlice instruction should have " << expected_operands + << " operands, but has " << proto.operand_ids_size(); + } + const auto& operand_vector = all_operands(); + instruction = CreateDynamicSlice( + shape, operands(0), absl::MakeSpan(operand_vector).subspan(1), + slice_sizes); + break; + } + case HloOpcode::kDynamicUpdateSlice: { + TF_RET_CHECK(proto.operand_ids_size() >= 2) + << "DynamicUpdateSlice instruction should have at least 2 operands " + "but sees " + << proto.operand_ids_size(); + // TODO(b/118437727): Old form, make the check unconditional. 
+ if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) { + auto expected_operands = 2 + operands(0)->shape().rank(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << "DynamicUpdateSlice instruction should have " + << expected_operands << " operands, but has " + << proto.operand_ids_size(); + } + const auto& operand_vector = all_operands(); instruction = - CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes); + CreateDynamicUpdateSlice(shape, operands(0), operands(1), + absl::MakeSpan(operand_vector).subspan(2)); + break; } case HloOpcode::kGather: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Gather instruction should have 2 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_gather_dimension_numbers()) << "Gather instruction should have GatherDimensionNumbers set."; std::unique_ptr gather_dimension_numbers = @@ -475,9 +461,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kScatter: { - TF_RET_CHECK(proto.operand_ids_size() == 3) - << "Scatter instruction should have 3 operands but sees " - << proto.operand_ids_size(); TF_RET_CHECK(proto.has_scatter_dimension_numbers()) << "Scatter instruction should have ScatterDimensionNumbers set."; TF_RET_CHECK(proto.called_computation_ids_size() == 1) @@ -499,9 +482,6 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kDot: { TF_RET_CHECK(proto.has_dot_dimension_numbers()) << "Dot instruction should have dot_dimension_numbers."; - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Dot instruction should have 2 operands but sees " - << proto.operand_ids_size(); PrecisionConfig precision_config = proto.precision_config(); precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); @@ -511,9 +491,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kDomain: { - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Domain instruction should have 1 operands but 
sees " - << proto.operand_ids_size(); std::shared_ptr entry_hlo_sharding; std::shared_ptr exit_hlo_sharding; if (proto.has_domain_entry_sharding()) { @@ -535,7 +512,6 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kGetDimensionSize: - TF_RET_CHECK(proto.operand_ids_size() == 1); TF_RET_CHECK(proto.dimensions_size() == 1); instruction = CreateGetDimensionSize(shape, operands(0), proto.dimensions(0)); @@ -569,6 +545,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + + TF_RET_CHECK(proto.id() >= 0) + << "Instruction with negative id: " << proto.id(); + TF_RET_CHECK(proto.id() <= INT_MAX) + << "Instruction with id > INT_MAX: " << proto.id(); instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { @@ -619,7 +600,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, absl::Span operands) { if (opcode == HloOpcode::kCopy) { // It is impossible to copy an opaque shape, we don't know how big it is. 
- CHECK(!ShapeUtil::IsOpaque(shape)); + CHECK(!shape.IsOpaque()); } auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (auto operand : operands) { @@ -650,8 +631,10 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: break; default: @@ -729,12 +712,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { return absl::make_unique( - shape, lhs, rhs, feature_group_count, window, dimension_numbers, - precision_config); + shape, lhs, rhs, feature_group_count, batch_group_count, window, + dimension_numbers, precision_config); } /* static */ std::unique_ptr HloInstruction::CreateFft( @@ -744,6 +727,13 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, fft_length); } +/* static */ std::unique_ptr +HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a, + HloInstruction* b, + const TriangularSolveOptions& options) { + return absl::make_unique(shape, a, b, options); +} + /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, @@ -761,8 +751,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, shape, operand, exponent_bits, mantissa_bits); } -/* static */ std::unique_ptr -HloInstruction::CreateCrossReplicaSum( +/* static */ std::unique_ptr HloInstruction::CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const 
std::vector& replica_groups, absl::string_view barrier, @@ -787,6 +776,11 @@ HloInstruction::CreateCollectivePermute( shape, operand, source_target_pairs); } +/* static */ std::unique_ptr HloInstruction::CreateReplicaId() { + return absl::WrapUnique( + new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {}))); +} + /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -903,23 +897,19 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( - const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, absl::Span slice_sizes) { return absl::make_unique( shape, operand, start_indices, slice_sizes); } /* static */ std::unique_ptr -HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) { - auto instruction = absl::WrapUnique( - new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(update); - instruction->AppendOperand(start_indices); - return instruction; +HloInstruction::CreateDynamicUpdateSlice( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices) { + return absl::make_unique( + shape, operand, update, start_indices); } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( @@ -1035,7 +1025,7 @@ HloInstruction::CreateBroadcastSequence( const std::function)>& adder) { CHECK(ShapeUtil::IsScalar(operand->shape()) || - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); + operand->shape().rank() == output_shape.rank()); Shape broadcast_shape = ShapeUtil::ChangeElementType( output_shape, operand->shape().element_type()); // Do explicit broadcast for scalar. 
@@ -1051,7 +1041,7 @@ HloInstruction::CreateBroadcastSequence( // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { + for (int i = 0; i < operand->shape().rank(); i++) { if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand->shape().dimensions(i)); @@ -1107,9 +1097,11 @@ HloInstruction::CreateBroadcastSequence( } /* static */ std::unique_ptr HloInstruction::CreateSort( - const Shape& shape, int64 dimension, HloInstruction* keys, - absl::Span values) { - return absl::make_unique(shape, dimension, keys, values); + const Shape& shape, int64 dimension, + absl::Span operands, HloComputation* compare, + bool is_stable) { + return absl::make_unique(shape, dimension, operands, + compare, is_stable); } /* static */ std::unique_ptr HloInstruction::CreateFusion( @@ -1128,7 +1120,7 @@ HloInstruction::CreateBroadcastSequence( void HloInstruction::set_single_sharding(const HloSharding& sharding) { CHECK(!sharding.IsTuple()) << sharding; - if (ShapeUtil::IsTuple(shape())) { + if (shape().IsTuple()) { set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); } else { set_sharding(sharding); @@ -1160,7 +1152,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kOutfeed: case HloOpcode::kTrace: return true; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: return all_reduce_id().has_value(); default: return false; @@ -1283,7 +1275,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kParameter: case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: @@ -1301,6 +1293,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case 
HloOpcode::kDot: case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: + case HloOpcode::kTriangularSolve: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1321,8 +1314,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1378,9 +1373,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateReshape(shape, new_operands[0]); break; case HloOpcode::kDynamicUpdateSlice: - CHECK_EQ(new_operands.size(), 3); clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], - new_operands[2]); + new_operands.subspan(2)); break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); @@ -1408,6 +1402,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 2); clone = CreateAddDependency(new_operands[0], new_operands[1]); break; + case HloOpcode::kReplicaId: + CHECK_EQ(new_operands.size(), 0); + clone = CreateReplicaId(); + break; } // SetupDerivedInstruction will setup the precision_config_ field. 
SetupDerivedInstruction(clone.get()); @@ -1542,12 +1540,10 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const { Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); - if (std::find(control_successors_.begin(), control_successors_.end(), - instruction) == control_successors_.end()) { + if (!absl::c_linear_search(control_successors_, instruction)) { control_successors_.push_back(instruction); - TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(), - instruction->control_predecessors_.end(), - this) == instruction->control_predecessors_.end()); + TF_RET_CHECK( + !absl::c_linear_search(instruction->control_predecessors_, this)); instruction->control_predecessors_.push_back(this); } return Status::OK(); @@ -1679,13 +1675,16 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kReshape: + case HloOpcode::kReplicaId: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRsqrt: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: @@ -1740,7 +1739,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReducePrecision: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: @@ -1754,13 +1753,19 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDot: case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: + case HloOpcode::kTriangularSolve: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } return false; } -uint64 HloInstruction::Hash() const { +static uint64 HashOperand(const HloInstruction* 
hlo) { + return ShapeUtil::Hash(hlo->shape()); +} + +uint64 HloInstruction::Hash( + const std::function& hash_operand) const { using tensorflow::Hash64Combine; uint64 hash_value = Hash64Combine(0, static_cast(opcode())); @@ -1769,7 +1774,7 @@ uint64 HloInstruction::Hash() const { if (!IsCrossModuleAllReduce()) { if (!operands().empty()) { for (size_t i = 0; i < operands().size(); ++i) { - hash_value = Hash64Combine(hash_value, operand(i)->Hash()); + hash_value = Hash64Combine(hash_value, hash_operand(operand(i))); } } } @@ -1778,6 +1783,11 @@ uint64 HloInstruction::Hash() const { return hash_value; } +uint64 HloInstruction::Hash() const { + // Use HashOperand as an argument to prevent non-termination. + return Hash(HashOperand); +} + uint64 HloInstruction::InnerHash() const { return 13; } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1786,7 +1796,7 @@ void HloInstruction::RemoveUser(HloInstruction* user) { user_set_.erase(set_it); // This is linear in the number of the users, but a vector provides a stable // iteration order and much faster traversal. 
- auto vec_it = std::find(users_.begin(), users_.end(), user); + auto vec_it = absl::c_find(users_, user); CHECK(vec_it != users_.end()); users_.erase(vec_it); } @@ -1798,14 +1808,17 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); + return ReplaceUseWithDifferentShape(user, new_producer); +} +Status HloInstruction::ReplaceUseWithDifferentShape( + HloInstruction* user, HloInstruction* new_producer) { VLOG(3) << "Replacing uses of " << name() << " in " << user->name() << " with " << new_producer->name(); RemoveUser(user); - TF_RET_CHECK( - std::count(user->operands_.begin(), user->operands_.end(), this) >= 0); + TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0); std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); @@ -1818,6 +1831,16 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, Status HloInstruction::ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand) { + auto old_operand = operand(operand_num); + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), + new_operand->shape())) + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); + return ReplaceOperandWithDifferentShape(operand_num, new_operand); +} + +Status HloInstruction::ReplaceOperandWithDifferentShape( + int64 operand_num, HloInstruction* new_operand) { TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); @@ -1825,17 +1848,12 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, return Status::OK(); } - TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), - new_operand->shape())) - << old_operand->shape() << " is not compatible with " - << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << 
"Replacing operand " << operand_num << " of " << name() << " with " << new_operand->name() << ", was " << old_operand->name(); - if (std::find(operands_.begin(), operands_.end(), old_operand) == - operands_.end()) { + if (!absl::c_linear_search(operands_, old_operand)) { old_operand->RemoveUser(this); } new_operand->AddUser(this); @@ -1843,6 +1861,14 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, } Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) + << shape() << " is not compatible with " << new_producer->shape(); + return ReplaceAllUsesWithDifferentShape(new_producer); +} + +Status HloInstruction::ReplaceAllUsesWithDifferentShape( + HloInstruction* new_producer) { bool new_producer_is_user = false; for (HloInstruction* user : users()) { if (user == new_producer) { @@ -1867,7 +1893,8 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { AddUser(new_producer); } if (parent_ && parent_->root_instruction() == this) { - parent_->set_root_instruction(new_producer); + parent_->set_root_instruction(new_producer, + /*accept_different_shape=*/true); } return Status::OK(); @@ -1879,8 +1906,9 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: + case HloOpcode::kSort: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1898,8 +1926,9 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: + case HloOpcode::kSort: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -2010,8 +2039,10 @@ bool 
HloInstruction::IsElementwiseImpl( case HloOpcode::kNegate: case HloOpcode::kReal: case HloOpcode::kReducePrecision: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: CHECK_EQ(1, operand_count()); return true; @@ -2056,7 +2087,11 @@ bool HloInstruction::IsElementwiseImpl( } bool HloInstruction::IsCrossModuleAllReduce() const { - return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id(); + return opcode() == HloOpcode::kAllReduce && all_reduce_id(); +} + +bool HloInstruction::IsCrossReplicaAllReduce() const { + return opcode() == HloOpcode::kAllReduce && !all_reduce_id(); } string HloInstruction::ToStringWithCanonicalNameMap( @@ -2167,8 +2202,9 @@ std::vector HloInstruction::ExtraAttributesToString( } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || - opcode() == HloOpcode::kCrossReplicaSum || - opcode() == HloOpcode::kScatter) { + opcode() == HloOpcode::kAllReduce || + opcode() == HloOpcode::kScatter || + opcode() == HloOpcode::kSort) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2203,8 +2239,9 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: + case HloOpcode::kSort: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2400,12 +2437,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleConvolution(this); case HloOpcode::kFft: return visitor->HandleFft(this); - case HloOpcode::kCrossReplicaSum: - return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllReduce: + return visitor->HandleAllReduce(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); 
case HloOpcode::kCollectivePermute: return visitor->HandleCollectivePermute(this); + case HloOpcode::kReplicaId: + return visitor->HandleReplicaId(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2440,6 +2479,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCos(this); case HloOpcode::kSin: return visitor->HandleSin(this); + case HloOpcode::kSqrt: + return visitor->HandleSqrt(this); + case HloOpcode::kRsqrt: + return visitor->HandleRsqrt(this); case HloOpcode::kReal: return visitor->HandleReal(this); case HloOpcode::kImag: @@ -2508,6 +2551,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleIota(this); case HloOpcode::kGetDimensionSize: return visitor->HandleGetDimensionSize(this); + case HloOpcode::kTriangularSolve: + return visitor->HandleTriangularSolve(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2806,7 +2851,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { } return UseKind::kReuse; case HloOpcode::kDynamicUpdateSlice: - // Dynamic-update-slice reuses only operand 2 (start_indices). + // Dynamic-update-slice reuses only start_indices. 
if (i == 0 || i == 1) { return UseKind::kUse; } @@ -2859,10 +2904,10 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding) { bool has_interior_padding = - std::any_of(padding.dimensions().begin(), padding.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.interior_padding() != 0; - }); + absl::c_any_of(padding.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.interior_padding() != 0; + }); return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { @@ -3219,6 +3264,19 @@ int64 HloInstruction::parameter_number() const { return Cast(this)->parameter_number(); } +void HloInstruction::set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers) { + return Cast(this) + ->set_parameter_replicated_at_leaf_buffers( + parameter_replicated_at_leaf_buffers); +} + +const absl::optional>& +HloInstruction::parameter_replicated_at_leaf_buffers() const { + return Cast(this) + ->parameter_replicated_at_leaf_buffers(); +} + int64 HloInstruction::tuple_index() const { return Cast(this)->tuple_index(); } @@ -3256,13 +3314,12 @@ HloInstruction::source_target_pairs() const { return Cast(this)->source_target_pairs(); } -string HloInstruction::cross_replica_sum_barrier() const { - return Cast(this)->cross_replica_sum_barrier(); +string HloInstruction::all_reduce_barrier() const { + return Cast(this)->all_reduce_barrier(); } -void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast(this)->set_cross_replica_sum_barrier( - barrier); +void HloInstruction::set_all_reduce_barrier(const string& barrier) { + return Cast(this)->set_all_reduce_barrier(barrier); } absl::optional HloInstruction::all_reduce_id() const { @@ -3308,6 +3365,18 @@ void HloInstruction::set_feature_group_count(int64 feature_group_count) { feature_group_count); } +int64 
HloInstruction::batch_group_count() const { + if (auto convolution = DynCast(this)) { + return convolution->batch_group_count(); + } + return Cast(this)->batch_group_count(); +} + +void HloInstruction::set_batch_group_count(int64 batch_group_count) { + Cast(this)->set_batch_group_count( + batch_group_count); +} + HloComputation* HloInstruction::select() const { return Cast(this)->select(); } @@ -3364,4 +3433,8 @@ const DomainMetadata& HloInstruction::operand_side_metadata() const { const DomainMetadata& HloInstruction::user_side_metadata() const { return Cast(this)->user_side_metadata(); } + +const TriangularSolveOptions& HloInstruction::triangular_solve_options() const { + return Cast(this)->triangular_solve_options(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a54716217d6bbc5c0601f5d9ff7bf4072a6b30f5..33cbb9a41bab838e02813e75e2ca6327f785b007 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -47,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -384,6 +385,14 @@ class HloInstruction { // Creates a random number generation instruction that fills a shape with // random numbers from a given distribution. + // + // The parameters to the instruction are interpreted as follows: + // + // - If `distribution` is RNG_UNIFORM, generates a number in range + // [param0, param1). + // + // - If `distribution` is RNG_NORMAL, generates a normally-distributed value + // with mean `param0` and standard deviation `param1`. 
static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, absl::Span parameters); @@ -426,7 +435,7 @@ class HloInstruction { // and window describes how the filter is applied to lhs. static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); @@ -435,6 +444,10 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, FftType fft_type, absl::Span fft_length); + static std::unique_ptr CreateTriangularSolve( + const Shape& shape, HloInstruction* a, HloInstruction* b, + const TriangularSolveOptions& options); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. static std::unique_ptr CreateDot( @@ -462,9 +475,7 @@ class HloInstruction { // `all_reduce_id`: for Allreduce nodes from different modules, if they have // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will // not be applied cross modules. - // - // TODO(b/117564385): Rename this to AllReduce. - static std::unique_ptr CreateCrossReplicaSum( + static std::unique_ptr CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, @@ -491,11 +502,14 @@ class HloInstruction { // Data is sent/received according to the (source_replica_id, // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a // target_replica_id in any pair, the output on that replica is a tensor - // conssits of 0(s) in `shape`. + // consisting of 0(s) in `shape`. static std::unique_ptr CreateCollectivePermute( const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs); + // Creates an instruction that returns a U32 replica ID.
+ static std::unique_ptr CreateReplicaId(); + // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr CreateConvert(const Shape& shape, @@ -560,13 +574,14 @@ class HloInstruction { // 'slice_sizes'. static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, absl::Span slice_sizes); + absl::Span start_indices, + absl::Span slice_sizes); // Creates a dynamic update slice instruction, which updates a slice // of 'operand' with 'update' and 'start_indices'. static std::unique_ptr CreateDynamicUpdateSlice( const Shape& shape, HloInstruction* operand, HloInstruction* update, - HloInstruction* start_indices); + absl::Span start_indices); // Creates a concatenate instruction, where the operands are concatenated on // the provided dimension. @@ -596,7 +611,6 @@ class HloInstruction { // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, // ..., inputN.value1) // ... - // TODO(b/112040122): Add support to this in HLO passes and in backends. static std::unique_ptr CreateReduce( const Shape& shape, absl::Span operands, absl::Span init_values, @@ -669,10 +683,15 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, absl::Span dimensions); - // Creates a sort op, with a keys operand, and optional values operands. + // Creates an n-ary sort op with a 'compare' computation which is used for + // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters, + // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at + // specific index positions which should be compared, and should return a + // PRED. 'is_stable' specifies whether stable sorting is required.
static std::unique_ptr CreateSort( - const Shape& shape, int64 dimension, HloInstruction* keys, - absl::Span values = {}); + const Shape& shape, int64 dimension, + absl::Span operands, HloComputation* compare, + bool is_stable); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For @@ -909,6 +928,14 @@ class HloInstruction { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO instructions, // with respect to HloInstruction::Identical() method. + // + // Uses hash_operand function to compute hash values of its operands. + // At the very top level, hash_operand should be non-recursive to prevent + // non-termination. + uint64 Hash( + const std::function& hash_operand) const; + + // Calls the above method with non-recursive hash_operand function. uint64 Hash() const; // Returns whether the instruction has a constant operand. @@ -922,11 +949,20 @@ class HloInstruction { // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); - // Replaces the specified operand with new_operand. + // Same as ReplaceUseWith(), but new_producer can have a different shape. + Status ReplaceUseWithDifferentShape(HloInstruction* user, + HloInstruction* new_producer); + + // Replaces the specified operand with new_operand. The old and new operands + // must have compatible shapes ignoring floating-point precision. // // This function does NOT remove duplicated operands even if this instruction // is a fusion, so that the existing operand numbers do not change. - Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); + Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand); + + // Same as ReplaceOperandWith(), but new_operand can have a different shape. 
+ Status ReplaceOperandWithDifferentShape(int64 operand_num, + HloInstruction* new_operand); // Replaces all uses of this instruction with the new producer. If // new_producer is a user of this instruction then new_producer remains a use @@ -935,10 +971,16 @@ class HloInstruction { // If this instruction is the root of its computation, sets the computation's // root to new_producer. // + // The new producer must have a compatible shape ignoring floating-point + // precision. + // // If a user is a fusion instruction, this function will remove any duplicated // operands of it which could be created due to this replacement. Status ReplaceAllUsesWith(HloInstruction* new_producer); + // Same as ReplaceAllUsesWith, but new_producer can have a different shape. + Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer); + // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when // complete. If ignore_control_predecessors is true, instructions only @@ -1174,9 +1216,12 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; - // Returns true if this is an cross module all-reduce instrucion. + // Returns true if this is a cross module all-reduce instruction. bool IsCrossModuleAllReduce() const; + // Returns true if this is a cross-replica all-reduce instruction. + bool IsCrossReplicaAllReduce() const; + // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. // @@ -1218,6 +1263,10 @@ class HloInstruction { // on the instruction's existing name. void UniquifyName(NameUniquer* name_uniquer); + // Clear the unique ID of the instruction so that it can be re-assigned, such + // as for the purpose of compacting the instruction unique IDs. 
+ void ClearUniqueIdInternal() { unique_id_ = -1; } + // Set the unique id for this instruction to "id" void SetUniqueId(int id) { CHECK_EQ(unique_id_, -1); // Should not be assigned already @@ -1251,6 +1300,9 @@ class HloInstruction { backend_config_ = std::move(config_str); } + bool is_default_config() const { return is_default_config_; } + void set_default_config() { is_default_config_ = true; } + // Returns a string representation of a proto in the format used by // raw_backend_config_string. // @@ -1421,6 +1473,15 @@ class HloInstruction { // Delegates to HloParameterInstruction::parameter_number. int64 parameter_number() const; + // Delegates to + // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers. + void set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers); + + // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers. + const absl::optional>& + parameter_replicated_at_leaf_buffers() const; + // Delegates to HloGetTupleElementInstruction::tuple_index. int64 tuple_index() const; @@ -1448,9 +1509,9 @@ class HloInstruction { // Delegates to HloCollectivePermuteInstruction::source_target_pairs. const std::vector>& source_target_pairs() const; - // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. - string cross_replica_sum_barrier() const; - void set_cross_replica_sum_barrier(const string& barrier); + // Delegates to HloAllReduceInstruction::all_reduce_barrier. + string all_reduce_barrier() const; + void set_all_reduce_barrier(const string& barrier); // Delegates to HloAllReduceInstruction::all_reduce_id. absl::optional all_reduce_id() const; @@ -1484,6 +1545,11 @@ class HloInstruction { void set_feature_group_count(int64 feature_group_count); + // The number of batch groups. Must be a divisor of the input batch dimension + int64 batch_group_count() const; + + void set_batch_group_count(int64 batch_group_count); + // Delegates to HloSelectAndScatterInstruction::select. 
HloComputation* select() const; @@ -1525,6 +1591,9 @@ class HloInstruction { // Delegates to HloDomainInstruction::user_side_metadata(). const DomainMetadata& user_side_metadata() const; + // Delegates to HloTriangularSolveInstruction::triangular_solve_options(). + const TriangularSolveOptions& triangular_solve_options() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1691,6 +1760,10 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // This field is assigned to true when backend_config_ is assigned to + // a default configuration. + bool is_default_config_ = false; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8048e332cb57747286758b75773b29ba154aa888..35f031f29a7aca8db7ebe2fbcfdcebb7a778d703 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -55,13 +56,13 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { } Status HandleParameter(HloInstruction* parameter) override { - EXPECT_EQ(0, count_.count(parameter)); + EXPECT_FALSE(count_.contains(parameter)); count_[parameter] = GetCountsForNode(parameter); return Status::OK(); } Status HandleConstant(HloInstruction* constant) override { - EXPECT_EQ(0, count_.count(constant)); + EXPECT_FALSE(count_.contains(constant)); count_[constant] = GetCountsForNode(constant); return Status::OK(); } @@ -69,25 +70,25 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add) override { auto lhs = add->operand(0); auto rhs = add->operand(1); - EXPECT_EQ(0, count_.count(add)); - EXPECT_GT(count_.count(lhs), 0); - EXPECT_GT(count_.count(rhs), 0); + EXPECT_FALSE(count_.contains(add)); + EXPECT_TRUE(count_.contains(lhs)); + EXPECT_TRUE(count_.contains(rhs)); count_[add] = GetCountsForNode(add); return Status::OK(); } Status HandleNegate(HloInstruction* negate) override { auto operand = negate->operand(0); - EXPECT_EQ(0, count_.count(negate)); - EXPECT_GT(count_.count(operand), 0); + EXPECT_FALSE(count_.contains(negate)); + EXPECT_TRUE(count_.contains(operand)); count_[negate] = GetCountsForNode(negate); return Status::OK(); } Status HandleMap(HloInstruction* map) override { - EXPECT_EQ(0, count_.count(map)); + EXPECT_FALSE(count_.contains(map)); for (HloInstruction* arg : map->operands()) { - EXPECT_GT(count_.count(arg), 0); + EXPECT_TRUE(count_.contains(arg)); } count_[map] = GetCountsForNode(map); return Status::OK(); @@ -96,9 +97,9 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override { auto arg = 
reduce->operand(0); auto init_value = reduce->operand(1); - EXPECT_EQ(0, count_.count(reduce)); - EXPECT_GT(count_.count(arg), 0); - EXPECT_GT(count_.count(init_value), 0); + EXPECT_FALSE(count_.contains(reduce)); + EXPECT_TRUE(count_.contains(arg)); + EXPECT_TRUE(count_.contains(init_value)); count_[reduce] = GetCountsForNode(reduce); return Status::OK(); } @@ -128,7 +129,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { } // Counters for HLOs. Maps HLO to a NumOpsAndUsers. - std::unordered_map count_; + absl::flat_hash_map count_; }; TEST_F(HloInstructionTest, BasicProperties) { @@ -137,7 +138,7 @@ TEST_F(HloInstructionTest, BasicProperties) { EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); - EXPECT_EQ(0, parameter->operand_count()); + EXPECT_FALSE(parameter->operand_count()); } TEST_F(HloInstructionTest, UserWithTwoOperands) { @@ -981,9 +982,9 @@ TEST_F(HloInstructionTest, FunctionVisitor) { module->AddEntryComputation(builder.Build()); int visit_num = 0; - std::unordered_map visit_order; + absl::flat_hash_map visit_order; EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) { - EXPECT_EQ(0, visit_order.count(inst)); + EXPECT_FALSE(visit_order.contains(inst)); visit_order[inst] = visit_num; visit_num++; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1ea02cf9c03866a598bec0e5356f0eb31ad27755..905a6fe08b4430ad862edf0886a57c9f7e9f7977 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -42,11 +42,9 @@ using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { std::vector operand_indices = instruction->OperandIndices(operand); - 
return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); + return absl::c_all_of(operand_indices, [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); } string PrecisionConfigToString(const PrecisionConfig& precision_config) { @@ -203,6 +201,57 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } +HloTriangularSolveInstruction::HloTriangularSolveInstruction( + const Shape& shape, HloInstruction* a, HloInstruction* b, + const TriangularSolveOptions& options) + : HloInstruction(HloOpcode::kTriangularSolve, shape), + triangular_solve_options_(options) { + AppendOperand(a); + AppendOperand(b); +} + +HloInstructionProto HloTriangularSolveInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_triangular_solve_options() = triangular_solve_options_; + return proto; +} + +std::vector HloTriangularSolveInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return { + StrCat("left_side=", + triangular_solve_options_.left_side() ? "true" : "false"), + StrCat("lower=", triangular_solve_options_.lower() ? "true" : "false"), + StrCat("unit_diagonal=", + triangular_solve_options_.unit_diagonal() ? 
"true" : "false"), + StrCat("transpose_a=", TriangularSolveOptions_Transpose_Name( + triangular_solve_options_.transpose_a()))}; +} + +bool HloTriangularSolveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + const auto& options = triangular_solve_options(); + const auto& other_options = casted_other.triangular_solve_options(); + + return options.left_side() == other_options.left_side() && + options.lower() == other_options.lower() && + options.unit_diagonal() == other_options.unit_diagonal() && + options.transpose_a() == other_options.transpose_a(); +} + +std::unique_ptr +HloTriangularSolveInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + shape, new_operands[0], new_operands[1], triangular_solve_options()); +} + HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, int64 channel_id, @@ -363,9 +412,9 @@ HloAllReduceInstruction::HloAllReduceInstruction( HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id) - : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands, replica_groups), - cross_replica_sum_barrier_(barrier), + all_reduce_barrier_(barrier), all_reduce_id_(all_reduce_id) { AppendComputation(reduce_computation); } @@ -381,16 +430,25 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { if (all_reduce_id_) { proto.set_all_reduce_id(*all_reduce_id_); } - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + proto.set_all_reduce_barrier(all_reduce_barrier_); return proto; } +bool HloAllReduceInstruction::IsNoop() const { + for (auto replica_group : replica_groups()) { + if 
(replica_group.replica_ids().size() != 1) { + return false; + } + } + return !all_reduce_id(); +} + std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector result = HloCollectiveInstruction::ExtraAttributesToStringImpl(options); - if (!cross_replica_sum_barrier().empty()) { - result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + if (!all_reduce_barrier().empty()) { + result.push_back(StrCat("barrier=\"", all_reduce_barrier(), "\"")); } if (all_reduce_id_) { result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); @@ -405,8 +463,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const auto& casted_other = static_cast(other); return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && - cross_replica_sum_barrier() == - casted_other.cross_replica_sum_barrier() && + all_reduce_barrier() == casted_other.all_reduce_barrier() && all_reduce_id() == casted_other.all_reduce_id(); } @@ -415,8 +472,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( - shape, new_operands, to_apply(), replica_groups(), - cross_replica_sum_barrier(), all_reduce_id()); + shape, new_operands, to_apply(), replica_groups(), all_reduce_barrier(), + all_reduce_id()); } HloAllToAllInstruction::HloAllToAllInstruction( @@ -603,14 +660,17 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( dimensions(), to_apply()); } -HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, - HloInstruction* keys, - absl::Span values) - : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { - AppendOperand(keys); - for (auto* value : values) { +HloSortInstruction::HloSortInstruction( + const Shape& shape, int64 dimension, + absl::Span operands, HloComputation* compare, + bool is_stable) + : 
HloInstruction(HloOpcode::kSort, shape), + dimensions_({dimension}), + is_stable_(is_stable) { + for (auto* value : operands) { AppendOperand(value); } + AppendComputation(compare); } HloInstructionProto HloSortInstruction::ToProto() const { @@ -618,12 +678,18 @@ HloInstructionProto HloSortInstruction::ToProto() const { for (int64 dimension : dimensions_) { proto.add_dimensions(dimension); } + proto.set_is_stable(is_stable()); return proto; } std::vector HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; + std::vector attrs; + attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}")); + if (is_stable()) { + attrs.push_back("is_stable=true"); + } + return attrs; } bool HloSortInstruction::IdenticalSlowPath( @@ -631,15 +697,20 @@ bool HloSortInstruction::IdenticalSlowPath( const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return dimensions() == casted_other.dimensions(); + if (dimensions() != casted_other.dimensions()) { + return false; + } + if (is_stable() != casted_other.is_stable()) { + return false; + } + return eq_computations(to_apply(), other.to_apply()); } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - HloInstruction* keys = new_operands[0]; - return absl::make_unique(shape, dimensions(0), keys, - new_operands.subspan(1)); + return absl::make_unique( + shape, dimensions(0), new_operands, to_apply(), is_stable()); } HloTransposeInstruction::HloTransposeInstruction( @@ -735,7 +806,7 @@ HloMapInstruction::HloMapInstruction(const Shape& shape, AppendComputation(map_computation); // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. 
- dimensions_.resize(ShapeUtil::Rank(shape)); + dimensions_.resize(shape.rank()); std::iota(dimensions_.begin(), dimensions_.end(), 0); } @@ -815,8 +886,7 @@ std::vector HloSliceInstruction::ExtraAttributesToStringImpl( std::vector bounds; bounds.reserve(slice_starts_.size()); const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); + absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; }); for (int i = 0; i < slice_starts_.size(); ++i) { string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); bounds.push_back( @@ -867,7 +937,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index) { Shape* mutable_array_subshape = ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + CHECK(mutable_array_subshape->IsArray()); // Normally array_subshape will always have a layout, but this invariant is // temporarily broken in LayoutAssignment::AssignLayouts. @@ -901,11 +971,11 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( string operands; // For constants, show the actual value in place of an empty operand list. if (literal_.has_value() && - ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. 
- string tmp = literal().ToString(); + string tmp = literal().ToStringWithoutShape(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); std::vector v = absl::StrSplit(tmp, ' '); bool first = true; @@ -1052,8 +1122,7 @@ HloInstruction* HloFusionInstruction::AddFusionOperand( void HloFusionInstruction::MergeFusionInstruction( HloFusionInstruction* instruction_to_merge) { - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); + CHECK(absl::c_linear_search(operands(), instruction_to_merge)); // Clone the instruction from which to merge fused instructions. std::unique_ptr cloned = instruction_to_merge->Clone(); HloFusionInstruction* cloned_fusion = @@ -1220,8 +1289,8 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( // corresponding fused parameter instruction. Renumber parameters as // necessary to make parameter numbers consistent with their index in the // fused_parameter_ vector. - bool in_operand_list = std::find(operands().begin(), operands().end(), - instruction_to_fuse) != operands().end(); + bool in_operand_list = + absl::c_linear_search(operands(), instruction_to_fuse); CHECK(add_output || in_operand_list); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { // We assume all uses of a kTuple operation are GTE ops, not another @@ -1325,7 +1394,7 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( if (newly_created_tuple_instr) { HloInstruction* new_instr = parent()->AddInstruction( HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr)); } int64 index = tuple_elements.size(); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { @@ -1372,8 +1441,14 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } +static uint64 HashOperandRecursive(const HloInstruction* hlo) { + return hlo->Hash(HashOperandRecursive); +} + uint64 
HloFusionInstruction::InnerHash() const { - return fused_instructions_computation()->Hash(); + // Use HashOperandRecursive to recursively compute hash on inner operands. + return fused_instructions_computation()->root_instruction()->Hash( + HashOperandRecursive); } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( @@ -1463,9 +1538,30 @@ HloParameterInstruction::HloParameterInstruction(int64 parameter_number, HloInstructionProto HloParameterInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_parameter_number(parameter_number_); + if (parameter_replicated_at_leaf_buffers_) { + for (bool replicated : *parameter_replicated_at_leaf_buffers_) { + proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers( + replicated); + } + } return proto; } +std::vector HloParameterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + if (!parameter_replicated_at_leaf_buffers_) { + return result; + } + std::vector buffers_replicated_strs; + for (bool replicated : *parameter_replicated_at_leaf_buffers_) { + buffers_replicated_strs.push_back(replicated ? 
"true" : "false"); + } + result.push_back(StrCat("parameter_replication={", + StrJoin(buffers_replicated_strs, ","), "}")); + return result; +} + string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -1649,11 +1745,12 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), + batch_group_count_(batch_group_count), window_(window), convolution_dimension_numbers_(dimension_numbers), precision_config_(precision_config) { @@ -1684,6 +1781,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); + proto.set_batch_group_count(batch_group_count_); *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1700,6 +1798,10 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + if (batch_group_count_ != 1) { + extra.push_back(StrCat("batch_group_count=", batch_group_count_)); + } + string precision_config_string = PrecisionConfigToString(precision_config_); if (!precision_config_string.empty()) { extra.push_back(precision_config_string); @@ -1717,6 +1819,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( if (feature_group_count_ != other.feature_group_count()) { return false; } + if (batch_group_count_ != other.batch_group_count()) { + return false; + } return 
protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), @@ -1731,8 +1836,9 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( - shape, new_operands[0], new_operands[1], feature_group_count_, window(), - convolution_dimension_numbers_, precision_config_); + shape, new_operands[0], new_operands[1], feature_group_count_, + batch_group_count_, window(), convolution_dimension_numbers_, + precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -1834,6 +1940,7 @@ HloCustomCallInstruction::HloCustomCallInstruction( custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), feature_group_count_(1), + batch_group_count_(1), layout_constrained_(false) { for (auto operand : operands) { AppendOperand(operand); @@ -1848,6 +1955,7 @@ HloCustomCallInstruction::HloCustomCallInstruction( custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), feature_group_count_(1), + batch_group_count_(1), layout_constrained_(true), operand_shapes_with_layout_(operand_shapes_with_layout.begin(), operand_shapes_with_layout.end()) { @@ -1868,6 +1976,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { proto.set_custom_call_target(custom_call_target_); proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); + proto.set_batch_group_count(batch_group_count_); if (layout_constrained()) { proto.set_constrain_layout(true); for (const Shape& shape : operand_shapes_with_layout_) { @@ -1891,6 +2000,9 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (feature_group_count_ != 1) { extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + if (batch_group_count_ != 1) { + extra.push_back(StrCat("batch_group_count=", 
batch_group_count_)); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. @@ -1934,6 +2046,20 @@ bool HloCustomCallInstruction::IdenticalSlowPath( if (feature_group_count_ != casted_other.feature_group_count_) { return false; } + if (batch_group_count_ != casted_other.batch_group_count_) { + return false; + } + if (layout_constrained() != casted_other.layout_constrained()) { + return false; + } + if (layout_constrained()) { + for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) { + if (!ShapeUtil::Equal(operand_shapes_with_layout_[i], + casted_other.operand_shapes_with_layout_[i])) { + return false; + } + } + } return custom_call_target_ == casted_other.custom_call_target_ && opaque_ == casted_other.opaque_; } @@ -1944,6 +2070,10 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { auto cloned = absl::make_unique( shape, new_operands, custom_call_target(), opaque()); + if (layout_constrained()) { + cloned->layout_constrained_ = true; + cloned->operand_shapes_with_layout_ = operand_shapes_with_layout(); + } if (window_ != nullptr) { cloned->set_window(*window_); } @@ -1951,6 +2081,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); } cloned->set_feature_group_count(feature_group_count_); + cloned->set_batch_group_count(batch_group_count_); return std::move(cloned); } @@ -1994,12 +2125,44 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes) - : HloInstruction(HloOpcode::kDynamicSlice, shape), + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); AppendOperand(start_indices); 
} +HloDynamicSliceInstruction::HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes) + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), + dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { + AppendOperand(operand); + for (HloInstruction* index : start_indices) { + AppendOperand(index); + } +} + +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + AppendOperand(start_indices); +} + +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + for (HloInstruction* index : start_indices) { + AppendOperand(index); + } +} + HloInstructionProto HloDynamicSliceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); for (int64 slice_size : dynamic_slice_sizes_) { @@ -2025,9 +2188,14 @@ std::unique_ptr HloDynamicSliceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 2); - return absl::make_unique( - shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); + if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) { + // TODO(b/118437727): Old form, remove this path. 
+ return absl::make_unique( + shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); + } else { + return absl::make_unique( + shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_); + } } HloGatherInstruction::HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index b5c28137a145667a977d39c9d3c40c6d36a8436e..4d23cb671f24623f56faa9b69015cef21752a799 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -131,6 +131,34 @@ class HloFftInstruction : public HloInstruction { std::vector fft_length_; }; +class HloTriangularSolveInstruction : public HloInstruction { + public: + explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, + HloInstruction* b, + const TriangularSolveOptions& options); + const TriangularSolveOptions& triangular_solve_options() const { + return triangular_solve_options_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + TriangularSolveOptions triangular_solve_options_; +}; + class HloSendRecvInstruction : public HloInstruction { public: // Returns the channel id associated with the instruction. 
The id is @@ -242,14 +270,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id); - // Returns the barrier config used for the CrossReplicaSum implementation of + // Returns the barrier config used for the AllReduce implementation of // each backend. - string cross_replica_sum_barrier() const { - return cross_replica_sum_barrier_; - } - void set_cross_replica_sum_barrier(string barrier) { - cross_replica_sum_barrier_ = barrier; - } + string all_reduce_barrier() const { return all_reduce_barrier_; } + void set_all_reduce_barrier(string barrier) { all_reduce_barrier_ = barrier; } absl::optional all_reduce_id() const { return all_reduce_id_; } void set_all_reduce_id(const absl::optional& all_reduce_id); @@ -257,6 +281,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns true if the AllReduce does no communication, so it's equivalent + // to a mem copy. + bool IsNoop() const; + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -270,8 +298,8 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // The string representation of the barrier config used for CrossReplicaSum. - string cross_replica_sum_barrier_; + // The string representation of the barrier config used for AllReduce. + string all_reduce_barrier_; // For Allreduce nodes from different modules, if they have the same // all_reduce_id, they will be 'Allreduce'd. 
If empty, Allreduce will not be @@ -418,8 +446,8 @@ class HloReduceInstruction : public HloInstruction { class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, - HloInstruction* keys, - absl::Span values = {}); + absl::Span operands, + HloComputation* compare, bool is_stable); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -432,6 +460,7 @@ class HloSortInstruction : public HloInstruction { HloInstruction* mutable_keys() { return mutable_operand(0); } // Returns the number of value operands. int64 values_count() const { return operand_count() - 1; } + bool is_stable() const { return is_stable_; } private: std::vector ExtraAttributesToStringImpl( @@ -446,6 +475,7 @@ class HloSortInstruction : public HloInstruction { HloCloneContext* context) const override; std::vector dimensions_; + bool is_stable_; }; class HloTransposeInstruction : public HloInstruction { @@ -787,10 +817,28 @@ class HloParameterInstruction : public HloInstruction { explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, const string& name); int64 parameter_number() const { return parameter_number_; } + + // Sets and gets the whether all replicas will receive the same parameter data + // for each leaf buffer in data parallelism. 
+ void set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers) { + CHECK_EQ(ShapeUtil::GetLeafCount(shape()), + parameter_replicated_at_leaf_buffers.size()); + parameter_replicated_at_leaf_buffers_.emplace( + parameter_replicated_at_leaf_buffers.begin(), + parameter_replicated_at_leaf_buffers.end()); + } + const absl::optional>& + parameter_replicated_at_leaf_buffers() const { + return parameter_replicated_at_leaf_buffers_; + } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function& @@ -804,6 +852,10 @@ class HloParameterInstruction : public HloInstruction { HloCloneContext* context) const override; int64 parameter_number_ = 0; + + // Specifies whether each buffer has the same parameter value on all replicas + // in data parallelism. + absl::optional> parameter_replicated_at_leaf_buffers_; }; class HloGetTupleElementInstruction : public HloInstruction { @@ -903,9 +955,7 @@ class HloOutfeedInstruction : public HloInstruction { HloInstruction* token_operand, absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. - const Shape& outfeed_shape() const { - return outfeed_shape_; - } + const Shape& outfeed_shape() const { return outfeed_shape_; } // Returns the config for the Outfeed instruction. const string& outfeed_config() const { return outfeed_config_; } // Returns a serialized representation of this instruction. 
@@ -933,7 +983,7 @@ class HloConvolutionInstruction : public HloInstruction { public: explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); const Window& window() const override { return window_; } @@ -949,6 +999,10 @@ class HloConvolutionInstruction : public HloInstruction { // dimension and output feature dimension. int64 feature_group_count() const { return feature_group_count_; } + // The number of feature groups. Must be a divisor of the input batch + // dimension. + int64 batch_group_count() const { return batch_group_count_; } + // Returns the information used to tell the implementation information about // what sort of precision is requested. The meaning of the field is backend // specific. At the moment, it is only supported for kConvolution and kDot. @@ -977,6 +1031,9 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count_; + // The number of feature groups. Must be a divisor of the input batch + // dimension. + int64 batch_group_count_; // Describes the window used for a convolution. Window window_; // Describes the dimension numbers used for a convolution. @@ -1099,7 +1156,11 @@ class HloCustomCallInstruction : public HloInstruction { void set_feature_group_count(int64 feature_group_count) { feature_group_count_ = feature_group_count; } + void set_batch_group_count(int64 batch_group_count) { + batch_group_count_ = batch_group_count; + } int64 feature_group_count() const { return feature_group_count_; } + int64 batch_group_count() const { return batch_group_count_; } // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; @@ -1134,6 +1195,7 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr convolution_dimension_numbers_; // The number of feature groups. This is used for grouped convolutions. int64 feature_group_count_; + int64 batch_group_count_; // Whether the result and operand layouts are constrained. bool layout_constrained_; // For layout-constrained custom calls, this vector holds the shape with @@ -1171,12 +1233,38 @@ class HloPadInstruction : public HloInstruction { PaddingConfig padding_config_; }; -class HloDynamicSliceInstruction : public HloInstruction { +class HloDynamicIndexInstruction : public HloInstruction { + public: + explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) + : HloInstruction(opcode, shape) {} + virtual int64 first_index_operand_number() const = 0; + + // Returns a subspan of operands which represent the start indices. + absl::Span index_operands() const { + return absl::MakeSpan(operands()).subspan(first_index_operand_number()); + } + + // Returns the shapes of the index operands. + std::vector index_shapes() const { + std::vector shapes; + auto indices = index_operands(); + for (const HloInstruction* index : indices) { + shapes.push_back(index->shape()); + } + return shapes; + } +}; + +class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { public: explicit HloDynamicSliceInstruction(const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes); + explicit HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes); // Old methods kept for smooth subclassing transition END. // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1189,6 +1277,8 @@ class HloDynamicSliceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; + int64 first_index_operand_number() const override { return 1; } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1206,6 +1296,19 @@ class HloDynamicSliceInstruction : public HloInstruction { std::vector dynamic_slice_sizes_; }; +class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { + public: + explicit HloDynamicUpdateSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices); + explicit HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices); + + int64 first_index_operand_number() const override { return 2; } +}; + class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 1390537101e95a08e4ba4eef7ae8d6059a40e916..2255383322873a39c7076e0f4f0dd541bc79014d 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,8 +17,10 @@ limitations under the License. #include +#include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -36,8 +38,8 @@ constexpr int kError = -2; // [a-zA-Z0-9_.-] bool IsIdentifierChar(char c) { - return isalnum(static_cast(c)) || c == '-' || c == '.' || - c == '_'; + return absl::ascii_isalnum(static_cast(c)) || c == '-' || + c == '.' 
|| c == '_'; } } // namespace @@ -82,15 +84,29 @@ tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( return tensorflow::RegexpStringPiece(begin, end - begin); } +TokKind HloLexer::LookAhead() { + if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) { + return GetKind(); + } + + const char* old_current_ptr = current_ptr_; + TokenState old_token_state = token_state_; + Lex(); + TokKind kind = GetKind(); + token_state_ = old_token_state; + current_ptr_ = old_current_ptr; + return kind; +} + TokKind HloLexer::LexToken() { while (true) { - token_start_ = current_ptr_; + token_state_.token_start = current_ptr_; int current_char = GetNextChar(); switch (current_char) { default: // [a-zA-Z_] - if (isalpha(static_cast(current_char)) || + if (absl::ascii_isalpha(static_cast(current_char)) || current_char == '_') { return LexIdentifier(); } @@ -125,12 +141,20 @@ TokKind HloLexer::LexToken() { return LexNumberOrPattern(); case '=': return TokKind::kEqual; + case '<': + if (current_char == '<' && PeekCurrentChar() == '=') { + current_ptr_++; + return TokKind::kLeq; + } + return TokKind::kError; case ',': return TokKind::kComma; case '%': return LexPercent(); case ':': return TokKind::kColon; + case '*': + return TokKind::kAsterisk; case '[': return TokKind::kLsquare; case ']': @@ -190,6 +214,15 @@ TokKind HloLexer::LexToken() { // A lone '/' is an error. return TokKind::kError; } + case '.': + if (PeekCurrentChar() == '.') { + current_ptr_++; + if (PeekCurrentChar() == '.') { + current_ptr_++; + return TokKind::kDots; + } + } + return TokKind::kError; case '"': return LexString(); } @@ -206,43 +239,37 @@ TokKind HloLexer::LexToken() { // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { - { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - // 'consumable' will be advanced iff its prefix matches the pattern. 
- static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"}; - if (RE2::Consume(&consumable, *shape_pattern)) { - auto status_or_shape = ShapeUtil::ParseShapeString( - StringPieceFromPointers(token_start_, consumable.begin())); - if (status_or_shape.ok()) { - // This is a shape string. - shape_val_ = status_or_shape.ValueOrDie(); - current_ptr_ = consumable.begin(); - return TokKind::kShape; - } - } - } - while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } // If followed by ':', it's a name. if (PeekCurrentChar() == ':') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip ':' return TokKind::kName; } // If followed by '=', it's a attribute name. if (PeekCurrentChar() == '=') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip '=' return TokKind::kAttributeName; } absl::string_view identifier = - StringPieceFromPointers(token_start_, current_ptr_); + StringPieceFromPointers(token_state_.token_start, current_ptr_); + + // Primitive type strings are reserved words. The exception is 'tuple' whose + // type is represented using nested parentheses without the string 'tuple'. + if (primitive_util::IsPrimitiveTypeName(identifier)) { + PrimitiveType primitive_type = + primitive_util::StringToPrimitiveType(identifier).ValueOrDie(); + if (primitive_type != TUPLE) { + token_state_.primitive_type_val = primitive_type; + return TokKind::kPrimitiveType; + } + } // See if this is a keyword. 
#define KEYWORD(STR) \ @@ -261,21 +288,23 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); + KEYWORD(sparse); #undef KEYWORD { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 dim_labels_pattern = { R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } } - str_val_ = string(identifier); + token_state_.str_val = string(identifier); return TokKind::kIdent; } @@ -283,13 +312,13 @@ TokKind HloLexer::LexIdentifier() { // name ::= [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexPercent() { const char* name_start = current_ptr_; - if (isalpha(static_cast(PeekCurrentChar())) || + if (absl::ascii_isalpha(static_cast(PeekCurrentChar())) || PeekCurrentChar() == '_') { current_ptr_++; while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } - str_val_.assign(name_start, current_ptr_); + token_state_.str_val.assign(name_start, current_ptr_); return TokKind::kName; } return TokKind::kError; @@ -307,12 +336,14 @@ TokKind HloLexer::LexPercent() { // int ::= [-]?[0-9]+ // negative inf ::= '-inf' TokKind HloLexer::LexNumberOrPattern() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 float_pattern = { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); + CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_), + &token_state_.decimal_val)); return 
TokKind::kDecimal; } @@ -324,27 +355,28 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } if (RE2::Consume(&consumable, *dxd_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDxD; } if (RE2::Consume(&consumable, *pad_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kPad; } static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (absl::SimpleAtoi(slice, &int64_val_)) { + auto slice = + StringPieceFromPointers(token_state_.token_start, current_ptr_); + if (absl::SimpleAtoi(slice, &token_state_.int64_val)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -403,16 +435,17 @@ absl::string_view HloLexer::GetLine(LocTy loc) const { } // Lexes quoted string with escaping characters. If matched, the quoted string -// will be unescaped and stored to str_val_. +// will be unescaped and stored to token_state_.str_val. 
TokKind HloLexer::LexString() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); absl::string_view raw = - StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1); string error; - if (!absl::CUnescape(raw, &str_val_, &error)) { + if (!absl::CUnescape(raw, &token_state_.str_val, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; return TokKind::kError; } @@ -433,6 +466,8 @@ string TokKindToString(TokKind kind) { return "kComma"; case TokKind::kColon: return "kColon"; + case TokKind::kAsterisk: + return "kAsterisk"; case TokKind::kLsquare: return "kLsquare"; case TokKind::kRsquare: @@ -447,6 +482,8 @@ string TokKindToString(TokKind kind) { return "kRparen"; case TokKind::kArrow: return "kArrow"; + case TokKind::kLeq: + return "kLeq"; case TokKind::kw_HloModule: return "kw_HloModule"; case TokKind::kw_ENTRY: @@ -467,6 +504,10 @@ string TokKindToString(TokKind kind) { return "kw_inf"; case TokKind::kNegInf: return "kNegInf"; + case TokKind::kw_sparse: + return "kw_sparse"; + case TokKind::kPrimitiveType: + return "kPrimitiveType"; case TokKind::kName: return "kName"; case TokKind::kAttributeName: @@ -481,12 +522,12 @@ string TokKindToString(TokKind kind) { return "kIdent"; case TokKind::kString: return "kString"; - case TokKind::kShape: - return "kShape"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: return "kDecimal"; + case TokKind::kDots: + return "kDots"; } } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index d6a2b292a3916b2ff85f278cf5cb9f1567df88fa..383fb4e862b8e32771879d055e663dc821a5c839 100644 --- 
a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -29,6 +28,60 @@ limitations under the License. namespace xla { +// Defines different kinds of tokens used by the HLO lexer. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kAsterisk, // * + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + kDots, // ... + + kArrow, // -> + kLeq, // <= + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + kw_maximal, + kw_replicated, + kw_nan, + kw_inf, + kw_sparse, + + kNegInf, // -inf + + // Typed tokens. + kPrimitiveType, // F32, PRED, etc. + kName, // %foo + kAttributeName, // dimensions= + kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers + kString, // "abcd\"\n" + kInt, // 42 + kDecimal, // 4.2 +}; + +string TokKindToString(TokKind kind); + // Lexer for the HloModule::ToString() format text. // // This class is meant to be used by hlo_parser.cc. 
You shouldn't need to use @@ -39,9 +92,9 @@ class HloLexer { current_ptr_ = buf_.begin(); } - TokKind Lex() { return current_kind_ = LexToken(); } + TokKind Lex() { return token_state_.current_kind = LexToken(); } - TokKind GetKind() const { return current_kind_; } + TokKind GetKind() const { return token_state_.current_kind; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: @@ -51,28 +104,28 @@ class HloLexer { case TokKind::kPad: case TokKind::kString: case TokKind::kIdent: - return str_val_; + return token_state_.str_val; default: LOG(FATAL) << "This token does not have string value"; } } - Shape GetShapeVal() const { - CHECK(GetKind() == TokKind::kShape); - return shape_val_; - } - tensorflow::int64 GetInt64Val() const { + int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); - return int64_val_; + return token_state_.int64_val; } double GetDecimalVal() const { CHECK(GetKind() == TokKind::kDecimal); - return decimal_val_; + return token_state_.decimal_val; + } + PrimitiveType GetPrimitiveTypeVal() const { + CHECK(GetKind() == TokKind::kPrimitiveType); + return token_state_.primitive_type_val; } typedef const char* LocTy; // Returns the location of the current token. - LocTy GetLoc() const { return token_start_; } + LocTy GetLoc() const { return token_state_.token_start; } // Returns the line and column of a location in the buffer. std::pair GetLineAndColumn(LocTy location) const; @@ -80,6 +133,9 @@ class HloLexer { // Returns the whole line given the location. absl::string_view GetLine(LocTy loc) const; + // Looks ahead one token and returns it. Lexer state is unchanged. + TokKind LookAhead(); + private: // Returns the current character. If it's neither the end of input buffer nor // an invalid character, moves the pointer forward. @@ -112,12 +168,15 @@ class HloLexer { const char* current_ptr_; // Information about the current token. 
- const char* token_start_ = nullptr; - TokKind current_kind_; - string str_val_; - Shape shape_val_; - tensorflow::int64 int64_val_; - double decimal_val_; + struct TokenState { + const char* token_start = nullptr; + TokKind current_kind; + string str_val; + int64 int64_val; + double decimal_val; + PrimitiveType primitive_type_val; + }; + TokenState token_state_; struct LineNoCacheTy { const char* last_query; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 5bf055f3c012fef687cdc275d62efdf2d4cd5e5c..e14bcfa7f67e736a4d04f5b236fb2df02cf150e0 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" @@ -36,11 +37,11 @@ namespace xla { namespace { using Worklist = std::deque; -using Workset = std::unordered_set; +using Workset = absl::flat_hash_set; void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { - if (workset->count(instruction) == 0) { + if (!workset->contains(instruction)) { worklist->push_back(instruction); workset->insert(instruction); VLOG(3) << "ADD instruction: " << instruction->name(); diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index e0ae1173c6114f0bc6ef18b2cfff9d54ccfe2faf..436cccb1fb9ecf6f4efad772c700c611b28ce628 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -403,9 +403,9 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() 
constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) @@ -436,9 +436,9 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { HloModule OutfeedLoop InnerWhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 235efb19ce4ed28a5cd9fe5ca52ae5d8e9e5ba3d..67488a6a9a0c9cba7f576f9036c3a0cbe1900fff 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -178,7 +178,7 @@ HLO_MATCHER(Constant); HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); -HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(AllReduce); HLO_MATCHER(CollectivePermute); HLO_MATCHER(Divide); HLO_MATCHER(Domain); @@ -312,8 +312,8 @@ inline ::testing::Matcher Shape( } inline ::testing::Matcher Shape( absl::string_view shape) { - return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( - ShapeUtil::ParseShapeString(shape).ValueOrDie())); + return ::testing::MakeMatcher( + new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie())); } inline ::testing::Matcher ShapeWithLayout( const class Shape& shape) { @@ -323,7 +323,7 @@ inline ::testing::Matcher ShapeWithLayout( inline ::testing::Matcher ShapeWithLayout( absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( - 
ShapeUtil::ParseShapeString(shape).ValueOrDie())); + ParseShape(shape).ValueOrDie())); } // Verifies the value of the HloSharing against the provided sharding object. diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 7227bfb27c74758d2b79e404afc9eb97a1ca894d..76cc29cbb7848eb424d07abf11a95ffd59e9eed6 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -118,7 +118,7 @@ class HloTrivialScheduler : public HloModulePass { }; // A trivial pass which clears the schedule currently set on the -// HloModule. After this pass runs HloModudle::has_schedule will return false. +// HloModule. After this pass runs HloModule::has_schedule will return false. class HloDescheduler : public HloModulePass { public: HloDescheduler() = default; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index fe8371384c0fa3900a9022f101ff0b296439cf16..8322870cfd6a89fc6f863da8fd4a3576e8845cd7 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -107,11 +107,10 @@ HloComputation* HloModule::AddEntryComputation( } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { - auto it = - std::find_if(computations_.begin(), computations_.end(), - [&to_remove](const std::unique_ptr& comp) { - return comp.get() == to_remove; - }); + auto it = absl::c_find_if( + computations_, [&to_remove](const std::unique_ptr& comp) { + return comp.get() == to_remove; + }); TF_RET_CHECK(it->get() == to_remove); computations_.erase(it); return Status::OK(); @@ -247,11 +246,39 @@ HloModuleProto HloModule::ToProto() const { return proto; } +Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const { + absl::flat_hash_set computation_names; + absl::flat_hash_set computation_ids; + absl::flat_hash_set instruction_names; + 
absl::flat_hash_set instruction_ids; + + for (const HloComputation* computation : computations()) { + TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) + << "Computation name is not unique: " << computation->name(); + computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); + + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) + << "Instruction name is not unique: " << instruction->name(); + instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); + } + } + return Status::OK(); +} + /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { VLOG(2) << "CreateFromProto()"; - XLA_VLOG_LINES(2, proto.DebugString()); + XLA_VLOG_LINES(3, proto.DebugString()); // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. @@ -304,11 +331,10 @@ StatusOr> HloModule::CreateFromProto( auto module = absl::make_unique(proto.name(), module_config); // Sort the computations in the proto id's order. - std::sort(computations.begin(), computations.end(), - [&](const std::unique_ptr& a, - const std::unique_ptr& b) { - return to_proto_id[a.get()] < to_proto_id[b.get()]; - }); + absl::c_sort(computations, [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); // Add sorted computations to the module. 
for (auto& computation : computations) { @@ -331,28 +357,8 @@ StatusOr> HloModule::CreateFromProto( DynamicParameterBinding::CreateFromProto( proto.dynamic_parameter_binding())); - absl::flat_hash_set computation_names; - absl::flat_hash_set instruction_names; - absl::flat_hash_set computation_ids; - absl::flat_hash_set instruction_ids; - for (HloComputation* computation : module->computations()) { - TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) - << "Computation name is not unique: " << computation->name(); - computation_names.insert(computation->name()); - - TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) - << "Computation id is not unique: " << computation->unique_id(); - computation_ids.insert(computation->unique_id()); - for (HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) - << "Instruction name is not unique: " << instruction->name(); - instruction_names.insert(instruction->name()); - - TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) - << "Instruction id is not unique: " << instruction->unique_id(); - instruction_ids.insert(instruction->unique_id()); - } - } + TF_RETURN_IF_ERROR( + module->CheckUniqueNamesAndIdsForComputationsAndInstructions()); if (proto.has_schedule()) { TF_ASSIGN_OR_RETURN( @@ -392,15 +398,12 @@ namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given // subcomputation. 
-bool IsUsedOutsideSubcomputation( - const HloInstruction& hlo, - const std::unordered_set& instructions_in_subcomputation) { - for (HloInstruction* user : hlo.users()) { - if (!instructions_in_subcomputation.count(user)) { - return true; - } - } - return false; +bool IsUsedOutsideSubcomputation(const HloInstruction& hlo, + const absl::flat_hash_set& + instructions_in_subcomputation) { + return absl::c_any_of(hlo.users(), [&](HloInstruction* user) { + return !instructions_in_subcomputation.contains(user); + }); } } // anonymous namespace @@ -411,9 +414,9 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // A map from original instructions to their counterparts in the new outlined // function. - std::unordered_map outlined_instructions; + absl::flat_hash_map outlined_instructions; // A set that contains all instructions to be outlined. - std::unordered_set instruction_set_to_outline( + absl::flat_hash_set instruction_set_to_outline( instructions_to_outline.begin(), instructions_to_outline.end()); std::vector arguments; std::vector outputs; @@ -502,7 +505,7 @@ std::vector HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). - std::set nonroot_computations; + absl::flat_hash_set nonroot_computations; for (auto& computation : computations_) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -515,19 +518,19 @@ std::vector HloModule::MakeComputationPostOrder() const { // Keep track of computations which have already been added to the post // order. This prevents duplication as an embedded computation may be called // from two different root computations. 
- std::set added_computations; + absl::flat_hash_set added_computations; std::vector post_order; for (auto& computation : computations_) { - if (nonroot_computations.count(computation.get()) == 0) { + if (!nonroot_computations.contains(computation.get())) { for (HloComputation* embedded_computation : computation->MakeEmbeddedComputationsList()) { - if (added_computations.count(embedded_computation) == 0) { + if (!added_computations.contains(embedded_computation)) { post_order.push_back(embedded_computation); added_computations.insert(embedded_computation); } } // Root computations should only be encountered once. - CHECK_EQ(0, added_computations.count(computation.get())); + CHECK(!added_computations.contains(computation.get())); post_order.push_back(computation.get()); added_computations.insert(computation.get()); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 7b9cbf9a53a2201b1312405bbd7ed2b88f65c9be..b6fe6a5cdbd0934014f1152acd48c7a5973bead3 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -136,7 +136,9 @@ class HloModule { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO modules, // with respect to HloInstruction::Identical() method. - uint64 Hash() const { return entry_computation()->Hash(); } + uint64 Hash() const { + return entry_computation()->root_instruction()->Hash(); + } // Gets the computations in this module. // @@ -185,6 +187,7 @@ class HloModule { std::vector MakeNonfusionComputations() const; const HloModuleConfig& config() const { return config_; } + void set_config(HloModuleConfig& config) { config_ = config; } // Return a string representation of the module. 
// @@ -262,6 +265,18 @@ class HloModule { const HloSchedule& schedule() const { return *schedule_; } HloSchedule& schedule() { return *schedule_; } + HloComputation* AddComputationAndUnifyNamesAndIds( + std::unique_ptr computation, bool is_entry) { + computation->ClearUniqueIdInternal(); + for (auto* instruction : computation->instructions()) { + instruction->ClearUniqueIdInternal(); + } + return AddComputationInternal(std::move(computation), is_entry, + /*uniquify_identifiers=*/true); + } + + Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index 31d26cc51e8217234526bbfeb83510aadf2c27b5..6b72ba128664d27c51aa8dcfa61fe959a0160c73 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -49,7 +49,7 @@ StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { auto* while_body_param = while_body_comp->parameter_instruction(0); auto* while_body_root = while_body_comp->root_instruction(); - if (!ShapeUtil::IsTuple(xla_while->shape()) || + if (!xla_while->shape().IsTuple() || while_body_root->opcode() != HloOpcode::kTuple) { // Only run DCE on tuple-shaped while loops where body root is Tuple, // with no I/O instructions. diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03..f6e2866204955ac024c2b6f972de449cc3df4c15 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -38,9 +38,7 @@ class HloModuleDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. 
bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - return std::find(computation.instructions().begin(), - computation.instructions().end(), - instruction) != computation.instructions().end(); + return absl::c_linear_search(computation.instructions(), instruction); } // Returns whether the while instruction with name 'while_name' in @@ -373,9 +371,9 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index b4aac4c8076cb69647d42c6243bc969d06d0709e..b877081be5775bf6c75a69ffeba28d0f2cc17f90 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -79,36 +79,36 @@ Status HloModuleGroupMetadata::Build() { return Status::OK(); } - std::vector peers; - if (IsChannelInstruction(hlo)) { - peers.push_back(PeerComputation(hlo)); - } else if (hlo->IsCrossModuleAllReduce()) { - for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { - if (instr == hlo) { - continue; + if (IsChannelInstruction(hlo) || hlo->IsCrossModuleAllReduce()) { + std::vector peers; + if (IsChannelInstruction(hlo)) { + peers.push_back(PeerComputation(hlo)); + } else if (hlo->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { + if (instr == hlo) { + continue; + } + peers.push_back(instr->parent()); } - peers.push_back(instr->parent()); } - } - - // Add the parent computation of 
this channel (or all-reduce) instruction - // and its peer computation(s) (both must be while computations) as - // companions. - for (HloComputation* peer_computation : peers) { - const TrackedInstruction* peer_tracked = - GetTrackedInstruction(peer_computation); - TF_RET_CHECK(peer_tracked != nullptr) - << "Peer instruction is not a possible companion"; - TF_RET_CHECK(*tracked == *peer_tracked) - << "Peer instruction does not match the computation kind"; - TF_RETURN_IF_ERROR( - AddCompanion(tracked->instruction(), peer_tracked->instruction())); - tracked_instructions_comms_[tracked->instruction()].push_back(hlo); - } - // Add the parents of companion instructions (they must be all of the same - // kind of instructions, opcode wise) as companions. - if (IsCompanionInstruction(hlo)) { + // Add the parent computation of this channel (or all-reduce) instruction + // and its peer computation(s) (both must be while computations) as + // companions. + for (HloComputation* peer_computation : peers) { + const TrackedInstruction* peer_tracked = + GetTrackedInstruction(peer_computation); + TF_RET_CHECK(peer_tracked != nullptr) + << "Peer instruction is not a possible companion"; + TF_RET_CHECK(*tracked == *peer_tracked) + << "Peer instruction does not match the computation kind"; + TF_RETURN_IF_ERROR( + AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); + } + } else if (IsCompanionInstruction(hlo)) { + // Add the parents of companion instructions (they must be all of the same + // kind of instructions, opcode wise) as companions. 
for (HloInstruction* companion : Companions(hlo)) { const TrackedInstruction* companion_tracked = GetTrackedInstruction(companion->parent()); @@ -118,6 +118,7 @@ Status HloModuleGroupMetadata::Build() { companion_tracked->instruction())); } } + return Status::OK(); }; @@ -198,7 +199,7 @@ bool HloModuleGroupMetadata::IsChannelInstruction( } bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { - return companion_set_index_.count(hlo) > 0; + return companion_set_index_.contains(hlo); } bool HloModuleGroupMetadata::InstructionCommunicates( @@ -388,9 +389,10 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, instruction1->opcode() == HloOpcode::kCall); VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " << instruction2->ToString(); - - if (!ContainsKey(companion_set_index_, instruction1) && - !ContainsKey(companion_set_index_, instruction2)) { + if (instruction1 == instruction2) { + return Status::OK(); + } else if (!ContainsKey(companion_set_index_, instruction1) && + !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( absl::make_unique>()); auto companion_set = companion_sets_.back().get(); @@ -418,7 +420,10 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, for (HloInstruction* hlo : Companions(instruction2)) { companion_set_index_[hlo] = companion_set_index_[instruction1]; } - companion_sets_.erase(companion_sets_.begin() + index_to_remove); + // We can't remove the set from the vector because companion_set_index_ + // references sets by their index in this vector, so we reset to nullptr + // instead. 
+ companion_sets_[index_to_remove].reset(nullptr); } return Status::OK(); } @@ -509,7 +514,7 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction( HloComputation* computation = instruction->parent(); const HloModule* module = computation->parent(); if (module->entry_computation() == computation || - tracked_instructions_.count(computation) > 0) { + tracked_instructions_.contains(computation)) { return Status::OK(); } return FailedPrecondition("channel is used in disallowed computation"); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 928df0f5a7444ad877961a5de970c752e1d024da..84f7f2f31339ae9e98ea2301b6e6d94fcf4dedbb 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -38,7 +38,7 @@ namespace xla { // Class for bookkeeping the information on the given modules, in particular on // the interaction between computations. // -// Companion instructions are one of the information collected as we build the +// Companion instructions are one piece of information collected as we build the // metadata. For example, for each While instruction, companion instructions // refer to a set of While instructions in other computations that communicate // with each other. @@ -51,6 +51,13 @@ namespace xla { // } While_4() { Recv(0) } // } // +// Each instruction can belong to at most one companion set: While_0 and While_5 +// are in the same set even though they don't communicate with each other, +// because they both communicate with While_2. +// +// A send and the matching recv must both have the same level of nesting of +// companion instructions. +// // Companion instructions are used to detect cycles in the graph and also for // global scheduling. 
class HloModuleGroupMetadata { @@ -166,12 +173,13 @@ class HloModuleGroupMetadata { // Returns the number of modules for devices (excluding the host module). int64 GetDeviceModulesCount() const; - // Returns the companion instructions for the given instruction. + // Returns the companion set for the given instruction, including the + // instruction itself. // // Precondition: IsCompanionWhile(instruction) is true. const std::vector& Companions( const HloInstruction* instruction) const { - CHECK_EQ(companion_set_index_.count(instruction), 1); + CHECK(companion_set_index_.contains(instruction)); return companion_set(companion_set_index_.at(instruction)); } @@ -215,11 +223,8 @@ class HloModuleGroupMetadata { // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone). // * The shape of channel instructions match. // * The nest level of channel instructions match. - // * Channel instructions are used in allowed computations; i.e., in the + // * Channel instructions are used in allowed computations, i.e., in the // entry computation of the module or condition/body of While computations. - // - // TODO(b/62064342): Currently, HloModuleGroupScheduler checks if there is a - // cycle in the graph, but it would be good to verify here. Status VerifyChannelInstructions(); // Adds metadata that the given two instructions are companions. @@ -231,8 +236,8 @@ class HloModuleGroupMetadata { Status CheckCommunicatingInstruction(HloInstruction* instruction) const; // Performs a consistency check on the companion sets built for the input - // modules. Check that a companion set does not include instructions from the - // same module/device. + // modules. Checks that each instruction in a companion set is in a different + // module/device. 
Status VerifyCompanionSets() const; // Retrieves a pointer to the stored TrackedInstruction associated with a diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index fddeb5f0a27a43ff9ca8b2b5d314bcfe91aaf0e6..91417bd2d9a6ca8a5192a37302e6a91e49a94d77 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -198,6 +198,8 @@ std::vector HloModuleGroupUtil::RootInstructions( for (HloComputation* computation : computations) { for (HloInstruction* instruction : computation->instructions()) { if (GlobalSuccessors(instruction).empty()) { + // An instruction that has no successors, e.g., an unused instruction, + // is in roots, even though it's not the ROOT of its computation. roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index f21b44bcd98d77b831de5d8a6afa4f9ddd91d15d..862666b48c9aa423ba4eeea3052c17fcc1064fd2 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -49,7 +49,7 @@ class HloModuleGroupUtil { // Returns all unique successors of the instruction. 
This includes: // * successors in the same computation: users and control successors // * Send is a successor of Recv - // * RecvDone is a predecessor of Send + // * RecvDone is a successor of Send // * successors of companions (if the instruction is a companion while) // * successors' companions (for any successor that is a companion while) std::vector GlobalSuccessors(HloInstruction* instruction); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 4551a1c2e259b06818f913cb6a9e782436b7e594..548fbb873aa646e061fb990454bb555d098607d8 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -53,8 +53,8 @@ StatusOr StringToHloOpcode(const string& opcode_name) { bool HloOpcodeIsComparison(HloOpcode opcode) { switch (opcode) { -#define CASE_IS_COMPARISON(enum_name, ...) \ - case HloOpcode::enum_name: \ +#define CASE_IS_COMPARISON(enum_name, opcode_name, ...) \ + case HloOpcode::enum_name: \ return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__); HLO_OPCODE_LIST(CASE_IS_COMPARISON) #undef CASE_IS_COMPARISON @@ -63,14 +63,25 @@ bool HloOpcodeIsComparison(HloOpcode opcode) { bool HloOpcodeIsVariadic(HloOpcode opcode) { switch (opcode) { -#define CASE_IS_VARIADIC(enum_name, ...) \ - case HloOpcode::enum_name: \ - return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__); +#define CASE_IS_VARIADIC(enum_name, opcode_name, arity, ...) \ + case HloOpcode::enum_name: \ + return arity == kHloOpcodeIsVariadic; HLO_OPCODE_LIST(CASE_IS_VARIADIC) #undef CASE_IS_VARIADIC } } +absl::optional HloOpcodeArity(HloOpcode opcode) { + switch (opcode) { +#define CASE_ARITY(enum_name, opcode_name, arity, ...) \ + case HloOpcode::enum_name: \ + return arity == kHloOpcodeIsVariadic ? 
absl::nullopt \ + : absl::make_optional(arity); + HLO_OPCODE_LIST(CASE_ARITY) +#undef CASE_ARITY + } +} + #undef HAS_PROPERTY #undef RESOLVE #undef CHECK_DEFAULT diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 127cfd165a5d8229cac3035f56a66f1bcfa734f3..c571664c81256e8dc319c97ddffa4e0f10609db2 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -30,9 +31,9 @@ namespace xla { // See the XLA documentation for the semantics of each opcode. // // Each entry has the format: -// (enum_name, opcode_name) +// (enum_name, opcode_name, arity) // or -// (enum_name, opcode_name, p1 | p2 | ...) +// (enum_name, opcode_name, arity, p1 | p2 | ...) // // with p1, p2, ... are members of HloOpcodeProperty. They are combined // using bitwise-or. 
@@ -44,102 +45,106 @@ namespace xla { // - In fully qualified names (HloInstruction::FullyQualifiedName()), to // separate the qualifiers (name of the computation and potentially the // fusion instruction) from the name -#define HLO_OPCODE_LIST(V) \ - V(kAbs, "abs") \ - V(kAdd, "add") \ - V(kAddDependency, "add-dependency") \ - V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ - V(kAllToAll, "all-to-all") \ - V(kAtan2, "atan2") \ - V(kBatchNormGrad, "batch-norm-grad") \ - V(kBatchNormInference, "batch-norm-inference") \ - V(kBatchNormTraining, "batch-norm-training") \ - V(kBitcast, "bitcast") \ - V(kBitcastConvert, "bitcast-convert") \ - V(kBroadcast, "broadcast") \ - V(kCall, "call", kHloOpcodeIsVariadic) \ - V(kCeil, "ceil") \ - V(kClamp, "clamp") \ - V(kCollectivePermute, "collective-permute") \ - V(kClz, "count-leading-zeros") \ - V(kComplex, "complex") \ - V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ - V(kConditional, "conditional") \ - V(kConstant, "constant") \ - V(kConvert, "convert") \ - V(kConvolution, "convolution") \ - V(kCopy, "copy") \ - V(kCos, "cosine") \ - V(kCrossReplicaSum, "cross-replica-sum") \ - V(kCustomCall, "custom-call") \ - V(kDivide, "divide") \ - V(kDomain, "domain") \ - V(kDot, "dot") \ - V(kDynamicSlice, "dynamic-slice") \ - V(kDynamicUpdateSlice, "dynamic-update-slice") \ - V(kEq, "equal-to", kHloOpcodeIsComparison) \ - V(kExp, "exponential") \ - V(kExpm1, "exponential-minus-one") \ - V(kFft, "fft") \ - V(kFloor, "floor") \ - V(kFusion, "fusion", kHloOpcodeIsVariadic) \ - V(kGather, "gather") \ - V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ - V(kGetDimensionSize, "get-dimension-size") \ - V(kGetTupleElement, "get-tuple-element") \ - V(kGt, "greater-than", kHloOpcodeIsComparison) \ - V(kImag, "imag") \ - V(kInfeed, "infeed") \ - V(kIota, "iota") \ - V(kIsFinite, "is-finite") \ - V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ - V(kLog, "log") \ - V(kLog1p, "log-plus-one") \ - V(kAnd, "and") \ - 
V(kNot, "not") \ - V(kOr, "or") \ - V(kXor, "xor") \ - V(kLt, "less-than", kHloOpcodeIsComparison) \ - V(kMap, "map", kHloOpcodeIsVariadic) \ - V(kMaximum, "maximum") \ - V(kMinimum, "minimum") \ - V(kMultiply, "multiply") \ - V(kNe, "not-equal-to", kHloOpcodeIsComparison) \ - V(kNegate, "negate") \ - V(kOutfeed, "outfeed") \ - V(kPad, "pad") \ - V(kParameter, "parameter") \ - V(kPower, "power") \ - V(kReal, "real") \ - V(kRecv, "recv") \ - V(kRecvDone, "recv-done") \ - V(kReduce, "reduce") \ - V(kReducePrecision, "reduce-precision") \ - V(kReduceWindow, "reduce-window") \ - V(kRemainder, "remainder") \ - V(kReshape, "reshape") \ - V(kReverse, "reverse") \ - V(kRng, "rng") \ - V(kRoundNearestAfz, "round-nearest-afz") \ - V(kScatter, "scatter") \ - V(kSelect, "select") \ - V(kSelectAndScatter, "select-and-scatter") \ - V(kSend, "send") \ - V(kSendDone, "send-done") \ - V(kShiftLeft, "shift-left") \ - V(kShiftRightArithmetic, "shift-right-arithmetic") \ - V(kShiftRightLogical, "shift-right-logical") \ - V(kSign, "sign") \ - V(kSin, "sine") \ - V(kSlice, "slice") \ - V(kSort, "sort") \ - V(kSubtract, "subtract") \ - V(kTanh, "tanh") \ - V(kTrace, "trace") \ - V(kTranspose, "transpose") \ - V(kTuple, "tuple", kHloOpcodeIsVariadic) \ - V(kTupleSelect, "tuple-select") \ - V(kWhile, "while") +#define HLO_OPCODE_LIST(V) \ + V(kAbs, "abs", 1) \ + V(kAdd, "add", 2) \ + V(kAddDependency, "add-dependency", 2) \ + V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ + V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \ + V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \ + V(kAtan2, "atan2", 2) \ + V(kBatchNormGrad, "batch-norm-grad", 5) \ + V(kBatchNormInference, "batch-norm-inference", 5) \ + V(kBatchNormTraining, "batch-norm-training", 3) \ + V(kBitcast, "bitcast", 1) \ + V(kBitcastConvert, "bitcast-convert", 1) \ + V(kBroadcast, "broadcast", 1) \ + V(kCall, "call", kHloOpcodeIsVariadic) \ + V(kCeil, "ceil", 1) \ + V(kClamp, "clamp", 3) \ + V(kCollectivePermute, 
"collective-permute", 1) \ + V(kClz, "count-leading-zeros", 1) \ + V(kComplex, "complex", 2) \ + V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConditional, "conditional", 3) \ + V(kConstant, "constant", 0) \ + V(kConvert, "convert", 1) \ + V(kConvolution, "convolution", 2) \ + V(kCopy, "copy", 1) \ + V(kCos, "cosine", 1) \ + V(kCustomCall, "custom-call", kHloOpcodeIsVariadic) \ + V(kDivide, "divide", 2) \ + V(kDomain, "domain", 1) \ + V(kDot, "dot", 2) \ + V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \ + V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \ + V(kEq, "equal-to", 2, kHloOpcodeIsComparison) \ + V(kExp, "exponential", 1) \ + V(kExpm1, "exponential-minus-one", 1) \ + V(kFft, "fft", 1) \ + V(kFloor, "floor", 1) \ + V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGather, "gather", 2) \ + V(kGe, "greater-than-or-equal-to", 2, kHloOpcodeIsComparison) \ + V(kGetDimensionSize, "get-dimension-size", 1) \ + V(kGetTupleElement, "get-tuple-element", 1) \ + V(kGt, "greater-than", 2, kHloOpcodeIsComparison) \ + V(kImag, "imag", 1) \ + V(kInfeed, "infeed", 1) \ + V(kIota, "iota", 0) \ + V(kIsFinite, "is-finite", 1) \ + V(kLe, "less-than-or-equal-to", 2, kHloOpcodeIsComparison) \ + V(kLog, "log", 1) \ + V(kLog1p, "log-plus-one", 1) \ + V(kAnd, "and", 2) \ + V(kNot, "not", 1) \ + V(kOr, "or", 2) \ + V(kXor, "xor", 2) \ + V(kLt, "less-than", 2, kHloOpcodeIsComparison) \ + V(kMap, "map", kHloOpcodeIsVariadic) \ + V(kMaximum, "maximum", 2) \ + V(kMinimum, "minimum", 2) \ + V(kMultiply, "multiply", 2) \ + V(kNe, "not-equal-to", 2, kHloOpcodeIsComparison) \ + V(kNegate, "negate", 1) \ + V(kOutfeed, "outfeed", 2) \ + V(kPad, "pad", 2) \ + V(kParameter, "parameter", 0) \ + V(kPower, "power", 2) \ + V(kReal, "real", 1) \ + V(kRecv, "recv", 1) \ + V(kRecvDone, "recv-done", 1) \ + V(kReduce, "reduce", kHloOpcodeIsVariadic) \ + V(kReducePrecision, "reduce-precision", 1) \ + V(kReduceWindow, "reduce-window", 2) \ + V(kRemainder, 
"remainder", 2) \ + V(kReplicaId, "replica-id", 0) \ + V(kReshape, "reshape", 1) \ + V(kReverse, "reverse", 1) \ + V(kRng, "rng", kHloOpcodeIsVariadic) \ + V(kRoundNearestAfz, "round-nearest-afz", 1) \ + V(kRsqrt, "rsqrt", 1) \ + V(kScatter, "scatter", 3) \ + V(kSelect, "select", 3) \ + V(kSelectAndScatter, "select-and-scatter", 3) \ + V(kSend, "send", 2) \ + V(kSendDone, "send-done", 1) \ + V(kShiftLeft, "shift-left", 2) \ + V(kShiftRightArithmetic, "shift-right-arithmetic", 2) \ + V(kShiftRightLogical, "shift-right-logical", 2) \ + V(kSign, "sign", 1) \ + V(kSin, "sine", 1) \ + V(kSlice, "slice", 1) \ + V(kSort, "sort", kHloOpcodeIsVariadic) \ + V(kSqrt, "sqrt", 1) \ + V(kSubtract, "subtract", 2) \ + V(kTanh, "tanh", 1) \ + V(kTrace, "trace", 1) \ + V(kTranspose, "transpose", 1) \ + V(kTriangularSolve, "triangular-solve", 2) \ + V(kTuple, "tuple", kHloOpcodeIsVariadic) \ + V(kTupleSelect, "tuple-select", 3) \ + V(kWhile, "while", 1) enum class HloOpcode { #define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name, @@ -147,12 +152,16 @@ enum class HloOpcode { #undef DECLARE_ENUM }; +// Arity value that denotes that an operator is variadic. +enum { + kHloOpcodeIsVariadic = -1, +}; + // List of properties associated with opcodes. // Properties are defined as increasing powers of two, so that we can use // bitwise-or to combine properties, and bitwise-and to test for them. enum HloOpcodeProperty { kHloOpcodeIsComparison = 1 << 0, - kHloOpcodeIsVariadic = 1 << 1, }; // Returns a string representation of the opcode. @@ -171,6 +180,10 @@ bool HloOpcodeIsComparison(HloOpcode opcode); // Returns true iff the given opcode has variadic operands. bool HloOpcodeIsVariadic(HloOpcode opcode); +// Returns the arity of opcode. If the opcode is variadic, +// returns nullopt. +absl::optional HloOpcodeArity(HloOpcode opcode); + // Returns the number of HloOpcode values. inline const uint32_t HloOpcodeCount() { #define HLO_COUNT_ONE(...) 
+1 diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 6f3f83f63a05fafaa3f3ddcff8a7cac7cb7b06d5..c599690f44e4eb2713c287e9f3d89a658771032f 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -54,11 +54,19 @@ TEST(HloOpcodeTest, OpcodeProperties) { EXPECT_FALSE(HloOpcodeIsComparison(opcode)); } switch (opcode) { + case HloOpcode::kAfterAll: + case HloOpcode::kAllReduce: + case HloOpcode::kAllToAll: case HloOpcode::kCall: case HloOpcode::kConcatenate: + case HloOpcode::kCustomCall: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kFusion: case HloOpcode::kMap: - case HloOpcode::kAfterAll: + case HloOpcode::kReduce: + case HloOpcode::kRng: + case HloOpcode::kSort: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index ca6a154809be46d6a0305c29e2b89219de408019..0cec61c257bb84e467290fb52ec9063a32ed558d 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -367,7 +367,7 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const HloInstruction* a, const HloInstruction* b) const { CHECK_EQ(a->parent(), b->parent()); // If either instruction is not in the order, then 'a' and 'b' are unordered. 
- if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { + if (!order_position_.contains(a) || !order_position_.contains(b)) { return false; } return order_position_.at(a) < order_position_.at(b); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 9b5bb5d0bd6af104ef62eaa5d3e53cedbe0213d3..4aa1090f48af0d674eb816cf0823395f08cc3836 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include #include "absl/algorithm/container.h" #include "absl/memory/memory.h" @@ -21,10 +22,13 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" @@ -43,8 +47,6 @@ using absl::StrCat; using absl::StrFormat; using absl::StrJoin; -const double kF16max = 65504; - // Creates and returns a schedule created using the order of the instructions in // the HloComputation::instructions() vectors in the module. 
HloSchedule ScheduleFromInstructionOrder(HloModule* module) { @@ -59,6 +61,10 @@ HloSchedule ScheduleFromInstructionOrder(HloModule* module) { return schedule; } +// Some functions accept either a linear index or a multi-dimensional index +// (used for indexing into sparse literals). +using LinearOrMultiIndex = absl::variant>; + // Parser for the HloModule::ToString() format text. class HloParser { public: @@ -74,7 +80,9 @@ class HloParser { string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); + StatusOr> ParseParameterReplicationOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); StatusOr ParsePaddingConfigOnly(); @@ -100,7 +108,7 @@ class HloParser { // Parse a single instruction worth of text. bool ParseSingleInstruction(HloModule* module); - // ParseXXX returns false if an error occurred. + // Parses a module, returning false if an error occurred. bool ParseHloModule(HloModule* module); bool ParseComputations(HloModule* module); @@ -116,21 +124,30 @@ class HloParser { bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); bool ParseDenseLiteral(Literal* literal, const Shape& shape); bool ParseSparseLiteral(Literal* literal, const Shape& shape); - template - bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); - // Sets the sub-value of literal at the given index to the given value. The - // literal's shape must have the default layout. - bool SetValueInLiteral(tensorflow::int64 value, - tensorflow::int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + // Sets the sub-value of literal at the given linear or sparse index to the + // given value. If the literal is dense, it myst have the default layout. + // + // `loc` should be the source location of the value. 
+ bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index, + Literal* literal); + bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index, Literal* literal); - bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index, Literal* literal); + bool SetValueInLiteral(LocTy loc, std::complex value, + LinearOrMultiIndex index, Literal* literal); + // `loc` should be the source location of the value. + template + bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, + LinearOrMultiIndex index, Literal* literal); + + // Checks whether the given value is within the range of LiteralNativeT. + // `loc` should be the source location of the value. template - bool SetValueInLiteralHelper(ParsedElemT value, - tensorflow::int64 linear_index, - Literal* literal); + bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value); + template + bool CheckParsedValueIsInRange(LocTy loc, std::complex value); bool ParseOperands(std::vector* operands); // Fills parsed operands into 'operands' and expects a certain number of @@ -141,9 +158,9 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; }; // The data parsed for the kDomain instruction. 
@@ -163,9 +180,11 @@ class HloParser { kBracedInt64ListList, kHloComputation, kFftType, + kTriangularSolveTranspose, kWindow, kConvolutionDimensionNumbers, kSharding, + kParameterReplication, kInstructionList, kSliceRanges, kPaddingConfig, @@ -230,21 +249,21 @@ class HloParser { bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + bool ParseParameterReplication(ParameterReplication* parameter_replication); // Parses the metadata behind a kDOmain instruction. bool ParseDomain(DomainData* domain); // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParsePrecisionList(std::vector* result); bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, - std::vector* result); + const TokKind delim, std::vector* result); // 'parse_and_add_item' is an lambda to parse an element in the list and add // the parsed element to the result. It's supposed to capture the result. 
bool ParseList(const TokKind start, const TokKind end, const TokKind delim, @@ -255,14 +274,20 @@ class HloParser { bool ParseName(string* result); bool ParseAttributeName(string* result); bool ParseString(string* result); + bool ParseDimensionSizes(std::vector* dimension_sizes, + std::vector* dynamic_dimensions); bool ParseShape(Shape* result); + bool ParseLayout(Layout* layout); + bool ParseTiles(std::vector* tiles); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); + bool ParseTriangularSolveTranspose(TriangularSolveOptions::Transpose* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); bool ParsePrecision(PrecisionConfig::Precision* result); - bool ParseInt64(tensorflow::int64* result); + bool ParseInt64(int64* result); bool ParseDouble(double* result); + bool ParseComplex(std::complex* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -279,9 +304,6 @@ class HloParser { // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. bool EatIfPresent(TokKind kind); - // Parses a shape, and returns true if the result is compatible with the given - // shape. - bool EatShapeAndCheckCompatible(const Shape& shape); // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. 
@@ -625,6 +647,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, std::unordered_map attrs; optional sharding; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + optional parameter_replication; + attrs["parameter_replication"] = {/*required=*/false, + AttrTy::kParameterReplication, + ¶meter_replication}; optional> predecessors; attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, &predecessors}; @@ -638,11 +664,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - tensorflow::int64 parameter_number; + int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || - !ParseInt64(¶meter_number) || - !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || + !ParseInt64(¶meter_number)) { + return false; + } + if (parameter_number < 0) { + Error(lexer_.GetLoc(), "parameter number must be >= 0"); + return false; + } + if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") || !ParseAttributes(attrs)) { return false; } @@ -664,7 +696,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { - optional iota_dimension; + optional iota_dimension; attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || @@ -693,8 +725,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: + case HloOpcode::kRsqrt: case HloOpcode::kSign: case HloOpcode::kSin: + case HloOpcode::kSqrt: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -766,7 +800,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateBitcastConvert(shape, operands[0])); break; 
} - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { optional>> tmp_groups; optional to_apply; optional> replica_group_ids; @@ -786,10 +820,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (tmp_groups) { replica_groups = CreateReplicaGroups(*tmp_groups); } - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, replica_groups, - barrier ? *barrier : "", all_reduce_id)); + instruction = builder->AddInstruction(HloInstruction::CreateAllReduce( + shape, operands, *to_apply, replica_groups, barrier ? *barrier : "", + all_reduce_id)); break; } case HloOpcode::kAllToAll: { @@ -829,6 +862,14 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); break; } + case HloOpcode::kReplicaId: { + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateReplicaId()); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -860,17 +901,21 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSort: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; + optional is_stable = false; + attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable}; + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || dimensions->size() != 1) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), - /*keys=*/operands[0], - /*values=*/absl::Span(operands).subspan(1))); + instruction = builder->AddInstruction( + HloInstruction::CreateSort(shape, 
dimensions->at(0), operands, + to_apply.value(), is_stable.value())); break; } case HloOpcode::kTuple: { @@ -896,7 +941,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -912,7 +957,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -930,7 +975,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -945,7 +990,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. 
optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -963,7 +1008,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -1006,11 +1051,14 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional batch_group_count; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/true, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64, + &batch_group_count}; optional> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; @@ -1024,6 +1072,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } + if (!batch_group_count) { + batch_group_count = 1; + } PrecisionConfig precision_config; if (operand_precision) { *precision_config.mutable_operand_precision() = { @@ -1034,12 +1085,13 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( shape, /*lhs=*/operands[0], /*rhs=*/operands[1], - feature_group_count.value(), *window, *dnums, precision_config)); + feature_group_count.value(), batch_group_count.value(), *window, + *dnums, precision_config)); break; } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, 
AttrTy::kBracedInt64List, &fft_length}; @@ -1051,8 +1103,40 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, shape, operands[0], *fft_type, *fft_length)); break; } + case HloOpcode::kTriangularSolve: { + optional left_side; + optional lower; + optional unit_diagonal; + optional transpose_a; + attrs["left_side"] = {/*required=*/false, AttrTy::kBool, &left_side}; + attrs["lower"] = {/*required=*/false, AttrTy::kBool, &lower}; + attrs["unit_diagonal"] = {/*required=*/false, AttrTy::kBool, + &unit_diagonal}; + attrs["transpose_a"] = {/*required=*/false, + AttrTy::kTriangularSolveTranspose, &transpose_a}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + TriangularSolveOptions options; + if (left_side) { + options.set_left_side(*left_side); + } + if (lower) { + options.set_lower(*lower); + } + if (unit_diagonal) { + options.set_unit_diagonal(*unit_diagonal); + } + options.set_transpose_a( + transpose_a ? *transpose_a : TriangularSolveOptions::NO_TRANSPOSE); + instruction = + builder->AddInstruction(HloInstruction::CreateTriangularSolve( + shape, operands[0], operands[1], options)); + break; + } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1064,7 +1148,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -1079,7 +1163,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; - optional> dimensions; + optional> dimensions; 
attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1095,7 +1179,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1116,7 +1200,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1160,31 +1244,46 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + LocTy loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.empty()) { + return Error(loc, "Expected at least one operand."); + } + if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) && + operands.size() != 1 + operands[0]->shape().rank()) { + return Error(loc, "Wrong number of operands."); + } instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( - shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + shape, /*operand=*/operands[0], + /*start_indices=*/absl::MakeSpan(operands).subspan(1), *dynamic_slice_sizes)); break; } case HloOpcode::kDynamicUpdateSlice: { - if (!ParseOperands(&operands, /*expected_size=*/3) 
|| - !ParseAttributes(attrs)) { + LocTy loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() < 2) { + return Error(loc, "Expected at least two operands."); + } + if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) && + operands.size() != 2 + operands[0]->shape().rank()) { + return Error(loc, "Wrong number of operands."); + } instruction = builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, /*operand=*/operands[0], /*update=*/operands[1], - /*start_indices=*/operands[2])); + /*start_indices=*/absl::MakeSpan(operands).subspan(2))); break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1198,7 +1297,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -1214,7 +1313,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -1231,7 +1330,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; 
if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -1280,7 +1379,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail // if the shape is not a non-empty tuple, so add guard so an error message // can be emitted instead of a check fail - if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) { + if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) { return Error(lexer_.GetLoc(), "infeed must have a non-empty tuple shape"); } @@ -1313,8 +1412,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -1352,6 +1451,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional batch_group_count; optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; @@ -1361,6 +1461,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64, + &batch_group_count}; attrs["operand_layout_constraints"] = { /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1416,19 +1518,22 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (feature_group_count.has_value()) { instruction->set_feature_group_count(*feature_group_count); } + if (batch_group_count.has_value()) { + 
instruction->set_batch_group_count(*batch_group_count); + } break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; optional> operand_precision; @@ -1472,19 +1577,19 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> offset_dims; + optional> offset_dims; attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &offset_dims}; - optional> collapsed_slice_dims; + optional> collapsed_slice_dims; attrs["collapsed_slice_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims}; - optional> start_index_map; + optional> start_index_map; attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List, &start_index_map}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> slice_sizes; + optional> slice_sizes; attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, &slice_sizes}; @@ -1506,17 +1611,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kScatter: { - optional> update_window_dims; + optional> update_window_dims; attrs["update_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims}; - optional> inserted_window_dims; + optional> inserted_window_dims; 
attrs["inserted_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims}; - optional> scatter_dims_to_operand_dims; + optional> scatter_dims_to_operand_dims; attrs["scatter_dims_to_operand_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &scatter_dims_to_operand_dims}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; @@ -1557,7 +1662,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); case HloOpcode::kGetDimensionSize: - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1582,6 +1687,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); } + if (parameter_replication) { + int leaf_count = ShapeUtil::GetLeafCount(instruction->shape()); + const auto& replicated = + parameter_replication->replicated_at_leaf_buffers(); + if (leaf_count != replicated.size()) { + return Error(lexer_.GetLoc(), + StrCat("parameter has ", leaf_count, + " leaf buffers, but parameter_replication has ", + replicated.size(), " elements.")); + } + instruction->set_parameter_replicated_at_leaf_buffers(replicated); + } if (predecessors) { for (auto* pre : *predecessors) { Status status = pre->AddControlDependencyTo(instruction); @@ -1646,8 +1763,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -1673,7 
+1790,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - tensorflow::int64 dim; + int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1685,7 +1802,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - tensorflow::int64 device; + int64 device; if (!ParseInt64(&device)) { return false; } @@ -1697,11 +1814,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } break; } - case TokKind::kShape: - // TODO(b/112302613): Left here for backward compatibility to ignore the - // removed tile shape data. - lexer_.Lex(); - break; case TokKind::kRbrace: break; default: @@ -1734,10 +1846,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "dimensions"); } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - for (tensorflow::int64 dim : tile_assignment_dimensions) { + for (int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (tensorflow::int64 device : devices) { + for (int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1746,6 +1858,32 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// parameter_replication ::= +// '{' ('true' | 'false')* (',' ('true' | 'false'))* '}' +bool HloParser::ParseParameterReplication( + ParameterReplication* parameter_replication) { + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start parameter_replication attribute")) { + return false; + } + + if (lexer_.GetKind() != TokKind::kRbrace) { + do { + if (lexer_.GetKind() == TokKind::kw_true) { + parameter_replication->add_replicated_at_leaf_buffers(true); + } else if (lexer_.GetKind() == TokKind::kw_false) { + parameter_replication->add_replicated_at_leaf_buffers(false); + } else { + return false; + } + lexer_.Lex(); + } while (EatIfPresent(TokKind::kComma)); + } + + return ParseToken(TokKind::kRbrace, + "expected '}' to end parameter_replication attribute"); +} + // domain ::= '{' 'kind=' domain_kind ',' 
'entry=' entry_sharding ',' // 'exit=' exit_sharding '}' bool HloParser::ParseDomain(DomainData* domain) { @@ -1798,142 +1936,145 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(tensorflow::int64 value, - tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, int64 value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S16: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S32: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S64: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case U8: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U16: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U32: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U64: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case PRED: // Bool type literals with rank >= 1 are printed in 0s and 1s. 
- return SetValueInLiteralHelper(static_cast(value), - linear_index, literal); + return SetValueInLiteralHelper(loc, static_cast(value), index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, double value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case F32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); case F64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); default: LOG(FATAL) << "unknown floating point primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, bool value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case PRED: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); default: LOG(FATAL) << PrimitiveType_Name(shape.element_type()) << " is not PRED type"; } } +bool HloParser::SetValueInLiteral(LocTy loc, std::complex value, + LinearOrMultiIndex index, Literal* literal) { + const Shape& shape = literal->shape(); + switch (shape.element_type()) { + case C64: + return SetValueInLiteralHelper>(loc, value, index, + literal); + case C128: + return 
SetValueInLiteralHelper>(loc, value, index, + literal); + default: + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " is not a complex type type"; + } +} + +template +string StringifyValue(T val) { + return StrCat(val); +} +template <> +string StringifyValue(std::complex val) { + return StrFormat("(%f, %f)", std::real(val), std::imag(val)); +} + template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, - tensorflow::int64 linear_index, +bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, + LinearOrMultiIndex index, Literal* literal) { - // Check that linear_index is in range. - if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { - return TokenError( - StrCat("trys to set value ", value, " to a literal in shape ", - ShapeUtil::HumanString(literal->shape()), " at linear index ", - linear_index, ", but the index is out of range")); + if (!CheckParsedValueIsInRange(loc, value)) { + return false; } - if (std::isnan(value) || - (std::numeric_limits::has_infinity && - (std::numeric_limits::infinity() == value || - -std::numeric_limits::infinity() == value))) { - // Skip range checking for non-finite value. 
- } else if (literal->shape().element_type() == F16 || - literal->shape().element_type() == BF16) { - if (value > kF16max || value < -kF16max) { - return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); + // Check that the index is in range and assign into the literal + if (auto* linear_index = absl::get_if(&index)) { + if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) { + return Error(loc, StrCat("trys to set value ", StringifyValue(value), + " to a literal in shape ", + ShapeUtil::HumanString(literal->shape()), + " at linear index ", *linear_index, + ", but the index is out of range")); } - } else if (std::is_unsigned::value) { - CHECK((std::is_same::value || - std::is_same::value)) - << "Unimplemented checking for ParsedElemT"; - - ParsedElemT upper_bound; - if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { - upper_bound = std::numeric_limits::max(); - } else { - upper_bound = - static_cast(std::numeric_limits::max()); - } - if (value > upper_bound || value < 0) { - // Value is out of range for LiteralNativeT. - return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); - } - } else if (value > static_cast( - std::numeric_limits::max()) || - value < static_cast( - std::numeric_limits::lowest())) { - // Value is out of range for LiteralNativeT. - return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); - } + literal->data().at(*linear_index) = + static_cast(value); + } else { + auto* multi_index = absl::get_if>(&index); + CHECK(multi_index != nullptr); - literal->data().at(linear_index) = - static_cast(value); - return true; -} + auto invalid_idx = [&](string msg) { + return Error(loc, StrFormat("Invalid sparse index [%s]. 
%s", + absl::StrJoin(*multi_index, ", "), msg)); + }; -bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { - Shape new_shape; - if (!ParseShape(&new_shape)) { - return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape))); - } - if (!ShapeUtil::Compatible(shape, new_shape)) { - return TokenError(StrCat( - "expects shape ", ShapeUtil::HumanString(shape), - ", but sees a different shape: ", ShapeUtil::HumanString(new_shape))); + const auto& shape = literal->shape(); + if (shape.rank() != multi_index->size()) { + return invalid_idx( + StrFormat("Has rank %d, but constant has shape %s, which has rank %d", + multi_index->size(), shape.ToString(), shape.rank())); + } + for (int64 i = 0; i < shape.rank(); ++i) { + auto idx = (*multi_index)[i]; + if (idx < 0) { + return invalid_idx(StrFormat( + "Sub-index value at %d, namely %d, cannot be negative.", i, idx)); + } + if (idx >= shape.dimensions(i)) { + return invalid_idx( + StrFormat("Sub-index at %d, namely %d, doesn't fit within shape " + "dimension %d in %s", + i, idx, shape.dimensions(i), shape.ToString())); + } + } + literal->AppendSparseElement(*multi_index, + static_cast(value)); } return true; } @@ -1942,8 +2083,8 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // ::= tuple // ::= non_tuple bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { - return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) - : ParseNonTupleLiteral(literal, shape); + return shape.IsTuple() ? 
ParseTupleLiteral(literal, shape) + : ParseNonTupleLiteral(literal, shape); } // tuple @@ -1952,10 +2093,6 @@ bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { // ::= /*empty*/ // ::= literal (',' literal)* bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return TokenError(StrCat("expects tuple constant in shape ", - ShapeUtil::HumanString(shape))); - } if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } @@ -1990,21 +2127,21 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { return ParseSparseLiteral(literal, shape); } - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true); return ParseDenseLiteral(literal, shape); } bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { - const tensorflow::int64 rank = ShapeUtil::Rank(shape); - if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { - return false; - } + // Cast `rank` to int because we call shape.dimensions(int rank) below, and if + // `rank` is an int64, that's an implicit narrowing conversion, which is + // implementation-defined behavior. + const int rank = static_cast(shape.rank()); // Create a literal with the given shape in default layout. *literal = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); - tensorflow::int64 nest_level = 0; - tensorflow::int64 linear_index = 0; + int64 nest_level = 0; + int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. 
For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -2012,17 +2149,35 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. - std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim( - elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), + elems_seen_per_dim.begin() + dim); return StrCat("[", StrJoin(elems_seen_until_dim, ",", - [](string* out, const tensorflow::int64& num_elems) { + [](string* out, const int64& num_elems) { StrAppend(out, num_elems - 1); }), "]"); }; + + auto add_one_elem_seen = [&] { + if (rank > 0) { + if (nest_level != rank) { + return TokenError(absl::StrFormat( + "expects nested array in rank %d, but sees %d", rank, nest_level)); + } + elems_seen_per_dim[rank - 1]++; + if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { + return TokenError(absl::StrFormat( + "expects %d elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); + } + } + return true; + }; + do { switch (lexer_.GetKind()) { default: @@ -2058,6 +2213,31 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { lexer_.Lex(); break; } + case TokKind::kLparen: { + if (!primitive_util::IsComplexType(shape.element_type())) { + return TokenError( + absl::StrFormat("unexpected '(' in literal. 
Parens are only " + "valid for complex literals")); + } + + std::complex value; + LocTy loc = lexer_.GetLoc(); + if (!add_one_elem_seen() || !ParseComplex(&value) || + !SetValueInLiteral(loc, value, linear_index++, literal)) { + return false; + } + break; + } + case TokKind::kDots: { + if (nest_level != 1) { + return TokenError(absl::StrFormat( + "expects `...` at nest level 1, but sees it at nest level %d", + nest_level)); + } + elems_seen_per_dim[0] = shape.dimensions(0); + lexer_.Lex(); + break; + } case TokKind::kComma: // Skip. lexer_.Lex(); @@ -2069,23 +2249,11 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { case TokKind::kw_nan: case TokKind::kw_inf: case TokKind::kNegInf: { - if (rank > 0) { - if (nest_level != rank) { - return TokenError( - absl::StrFormat("expects nested array in rank %d, but sees %d", - rank, nest_level)); - } - elems_seen_per_dim[rank - 1]++; - if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { - return TokenError(absl::StrFormat( - "expects %d elements on the minor-most dimension, but " - "sees more", - shape.dimensions(rank - 1))); - } - } + add_one_elem_seen(); if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, + if (!SetValueInLiteral(lexer_.GetLoc(), + lexer_.GetKind() == TokKind::kw_true, linear_index++, literal)) { return false; } @@ -2093,12 +2261,12 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } else if (primitive_util::IsIntegralType(shape.element_type()) || shape.element_type() == PRED) { LocTy loc = lexer_.GetLoc(); - tensorflow::int64 value; + int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal)) { + if (!SetValueInLiteral(loc, value, linear_index++, literal)) { return false; } } else if 
(primitive_util::IsFloatingPointType(shape.element_type())) { @@ -2109,7 +2277,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal)) { + if (!SetValueInLiteral(loc, value, linear_index++, literal)) { return false; } } else { @@ -2126,52 +2294,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return false; - } - - switch (shape.element_type()) { - case PRED: - return ParseSparseLiteralHelper(literal, shape); - case S8: - return ParseSparseLiteralHelper(literal, shape); - case S16: - return ParseSparseLiteralHelper(literal, shape); - case S32: - return ParseSparseLiteralHelper(literal, shape); - case S64: - return ParseSparseLiteralHelper(literal, shape); - case U8: - return ParseSparseLiteralHelper(literal, shape); - case U16: - return ParseSparseLiteralHelper(literal, shape); - case U32: - return ParseSparseLiteralHelper(literal, shape); - case U64: - return ParseSparseLiteralHelper(literal, shape); - case F16: - return ParseSparseLiteralHelper(literal, shape); - case F32: - return ParseSparseLiteralHelper(literal, shape); - case BF16: - return ParseSparseLiteralHelper(literal, shape); - case F64: - return ParseSparseLiteralHelper(literal, shape); - default: - return Error(lexer_.GetLoc(), - StrCat("invalid primitive type for sparse literal: ", - PrimitiveType_Name(shape.element_type()))); - } -} - -template -bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { - std::vector index; - - tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = Literal(shape); - if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { return false; @@ -2183,61 +2306,66 @@ bool 
HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { break; } - LocTy index_loc = lexer_.GetLoc(); - index.clear(); + std::vector index; if (lexer_.GetKind() == TokKind::kInt) { - tensorflow::int64 single_index = lexer_.GetInt64Val(); + int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); - if (rank != 1) { - return Error( - index_loc, - StrCat("invalid single-dimensional index for shape with rank ", - rank, ": ", single_index)); - } index.push_back(single_index); } else { if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, &index)) { return false; } - if (index.size() != rank) { - return Error( - index_loc, - StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", StrJoin(index, ", "), "]")); - } } if (!ParseToken(TokKind::kColon, "expects ':' after after the sparse array index and before " "the sparse array value")) { return false; } + LocTy value_loc = lexer_.GetLoc(); - LiteralNativeT value; if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - value = static_cast(lexer_.GetKind() == TokKind::kw_true); + bool value = lexer_.GetKind() == TokKind::kw_true; + if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) { + return false; + } lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - tensorflow::int64 value_s64; - if (!ParseInt64(&value_s64)) { + int64 value; + if (!ParseInt64(&value)) { return Error(value_loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - value = static_cast(value_s64); + if (!SetValueInLiteral(value_loc, value, index, literal)) { + return false; + } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { - double value_f64; - if (!ParseDouble(&value_f64)) { + double value; + if (!ParseDouble(&value)) { return Error(value_loc, StrCat("expects floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - value = 
static_cast(value_f64); + if (!SetValueInLiteral(value_loc, value, index, literal)) { + return false; + } + } else if (primitive_util::IsComplexType(shape.element_type())) { + std::complex value; + if (!ParseComplex(&value)) { + return Error(value_loc, + StrCat("expects complex value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + if (!SetValueInLiteral(value_loc, value, index, literal)) { + return false; + } } else { LOG(FATAL) << "Unexpected element type: " << PrimitiveType_Name(shape.element_type()); } + if (lexer_.GetKind() != TokKind::kRbrace && !ParseToken(TokKind::kComma, "expects ',' separator between sparse array elements")) { @@ -2251,14 +2379,114 @@ bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { StrCat("number of sparse elements exceeds maximum for layout: ", ShapeUtil::HumanStringWithLayout(shape))); } - - literal->AppendSparseElement(index, value); } literal->SortSparseElements(); return true; } +// MaxFiniteValue is a type-traits helper used by +// HloParser::CheckParsedValueIsInRange. +template +struct MinMaxFiniteValue { + static T max() { return std::numeric_limits::max(); } + static T min() { return std::numeric_limits::lowest(); } +}; + +template <> +struct MinMaxFiniteValue { + static double max() { + // Sadly this is not constexpr, so this forces `value` to be a method. + return static_cast(Eigen::NumTraits::highest()); + } + static double min() { return -max(); } +}; + +template <> +struct MinMaxFiniteValue { + static double max() { return static_cast(bfloat16::highest()); } + static double min() { return -max(); } +}; + +template +bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { + PrimitiveType literal_ty = + primitive_util::NativeToPrimitiveType(); + if (std::isnan(value) || + (std::numeric_limits::has_infinity && + (std::numeric_limits::infinity() == value || + -std::numeric_limits::infinity() == value))) { + // Skip range checking for non-finite value. 
+ } else if (std::is_unsigned::value) { + CHECK((std::is_same::value || + std::is_same::value)) + << "Unimplemented checking for ParsedElemT"; + + ParsedElemT upper_bound; + if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { + upper_bound = std::numeric_limits::max(); + } else { + upper_bound = + static_cast(std::numeric_limits::max()); + } + if (value > upper_bound || value < 0) { + // Value is out of range for LiteralNativeT. + return Error(loc, StrCat("value ", value, + " is out of range for literal's primitive type ", + PrimitiveType_Name(literal_ty), " namely [0, ", + upper_bound, "].")); + } + } else if (value > MinMaxFiniteValue::max() || + value < MinMaxFiniteValue::min()) { + // Value is out of range for LiteralNativeT. + return Error(loc, StrCat("value ", value, + " is out of range for literal's primitive type ", + PrimitiveType_Name(literal_ty), " namely [", + MinMaxFiniteValue::min(), ", ", + MinMaxFiniteValue::max(), "].")); + } + return true; +} + +template +bool HloParser::CheckParsedValueIsInRange(LocTy loc, + std::complex value) { + // e.g. `float` for std::complex + using LiteralComplexComponentT = + decltype(std::real(std::declval())); + + // We could do simply + // + // return CheckParsedValueIsInRange(std::real(value)) && + // CheckParsedValueIsInRange(std::imag(value)); + // + // but this would give bad error messages on failure. + + auto check_component = [&](absl::string_view name, double v) { + if (std::isnan(v) || v == std::numeric_limits::infinity() || + v == -std::numeric_limits::infinity()) { + // Skip range-checking for non-finite values. + return true; + } + + double min = MinMaxFiniteValue::min(); + double max = MinMaxFiniteValue::max(); + if (v < min || v > max) { + // Value is out of range for LitearlComplexComponentT. 
+ return Error( + loc, + StrCat(name, " part ", v, + " is out of range for literal's primitive type ", + PrimitiveType_Name( + primitive_util::NativeToPrimitiveType()), + ", namely [", min, ", ", max, "].")); + } + return true; + }; + return check_component("real", std::real(value)) && + check_component("imaginary", std::imag(value)); +} + // operands ::= '(' operands1 ')' // operands1 // ::= /*empty*/ @@ -2416,24 +2644,23 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kInt64: { - tensorflow::int64 result; + int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr) - ->emplace(result); + static_cast*>(attr_out_ptr)->emplace(result); return true; } case AttrTy::kInt32: { - tensorflow::int64 result; + int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if (result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -2473,6 +2700,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kTriangularSolveTranspose: { + TriangularSolveOptions::Transpose result; + if (!ParseTriangularSolveTranspose(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kWindow: { Window result; if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { @@ -2498,6 +2734,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(sharding); return true; } + case AttrTy::kParameterReplication: { + ParameterReplication parameter_replication; + if (!ParseParameterReplication(¶meter_replication)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(parameter_replication); + return true; + } case AttrTy::kInstructionList: { std::vector result; if 
(!ParseInstructionNames(&result)) { @@ -2517,19 +2762,19 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } case AttrTy::kBracedInt64ListList: { - std::vector> result; + std::vector> result; auto parse_and_add_item = [&]() { - std::vector item; + std::vector item; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &item)) { return false; @@ -2541,8 +2786,7 @@ bool HloParser::ParseAttributeHelper( parse_and_add_item)) { return false; } - static_cast>>*>( - attr_out_ptr) + static_cast>>*>(attr_out_ptr) ->emplace(result); return true; } @@ -2743,7 +2987,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( absl::string_view rhs = split2[0]; absl::string_view out = split2[1]; - const tensorflow::int64 rank = lhs.length(); + const int64 rank = lhs.length(); if (rank != rhs.length() || rank != out.length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -2753,7 +2997,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } auto is_unique = [](string str) -> bool { - std::sort(str.begin(), str.end()); + absl::c_sort(str); return std::unique(str.begin(), str.end()) == str.end(); }; @@ -2854,7 +3098,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -2924,9 +3168,9 @@ bool HloParser::ParseShapeList(std::vector* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + 
std::vector* result) { auto parse_and_add_item = [&]() { - tensorflow::int64 i; + int64 i; if (!ParseInt64(&i)) { return false; } @@ -2994,6 +3238,136 @@ bool HloParser::ParseParamList() { return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); } +// dimension_sizes ::= '[' dimension_list ']' +// dimension_list +// ::= /*empty*/ +// ::= <=? int64 (',' param)* +// param ::= name shape +bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, + std::vector* dynamic_dimensions) { + auto parse_and_add_item = [&]() { + int64 i; + bool is_dynamic = false; + if (lexer_.GetKind() == TokKind::kLeq) { + is_dynamic = true; + lexer_.Lex(); + } + if (!ParseInt64(&i)) { + return false; + } + dimension_sizes->push_back(i); + dynamic_dimensions->push_back(is_dynamic); + return true; + }; + return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, + parse_and_add_item); +} + +// tiles +// ::= /*empty*/ +// ::= 'T' '(' dim_list ')' +// dim_list +// ::= /*empty*/ +// ::= (int64 | '*') (',' (int64 | '*'))* +bool HloParser::ParseTiles(std::vector* tiles) { + auto parse_and_add_tile_dimension = [&]() { + tensorflow::int64 i; + if (ParseInt64(&i)) { + tiles->back().add_dimensions(i); + return true; + } + if (lexer_.GetKind() == TokKind::kAsterisk) { + tiles->back().add_dimensions(Tile::kCombineDimension); + lexer_.Lex(); + return true; + } + return false; + }; + + do { + tiles->push_back(Tile()); + if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma, + parse_and_add_tile_dimension)) { + return false; + } + } while (lexer_.GetKind() == TokKind::kLparen); + return true; +} + +// layout ::= '{' int64_list (':' tiles element_size_in_bits)? 
'}' +// element_size_in_bits +// ::= /*empty*/ +// ::= 'E' '(' int64 ')' +bool HloParser::ParseLayout(Layout* layout) { + std::vector minor_to_major; + std::vector tiles; + tensorflow::int64 element_size_in_bits = 0; + + auto parse_and_add_item = [&]() { + int64 i; + if (!ParseInt64(&i)) { + return false; + } + minor_to_major.push_back(i); + return true; + }; + + if (!ParseToken(TokKind::kLbrace, + StrCat("expects layout to start with ", + TokKindToString(TokKind::kLbrace)))) { + return false; + } + if (lexer_.GetKind() != TokKind::kRbrace) { + if (lexer_.GetKind() == TokKind::kInt) { + // Parse minor to major. + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + + if (lexer_.GetKind() == TokKind::kColon) { + lexer_.Lex(); + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") { + lexer_.Lex(); + ParseTiles(&tiles); + } + + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") { + // Parse element size in bits. 
+ lexer_.Lex(); + if (!ParseToken(TokKind::kLparen, + StrCat("expects element size in bits to start with ", + TokKindToString(TokKind::kLparen)))) { + return false; + } + if (!ParseInt64(&element_size_in_bits)) { + return false; + } + if (!ParseToken(TokKind::kRparen, + StrCat("expects element size in bits to end with ", + TokKindToString(TokKind::kRparen)))) { + return false; + } + } + } + } + if (!ParseToken(TokKind::kRbrace, + StrCat("expects layout to end with ", + TokKindToString(TokKind::kRbrace)))) { + return false; + } + + std::vector vec_tiles(tiles.size()); + for (int i = 0; i < tiles.size(); i++) { + vec_tiles[i] = Tile(tiles[i]); + } + *layout = + LayoutUtil::MakeLayout(minor_to_major, vec_tiles, element_size_in_bits); + return true; +} + // shape ::= shape_val_ // shape ::= '(' tuple_elements ')' // tuple_elements @@ -3017,19 +3391,74 @@ bool HloParser::ParseShape(Shape* result) { return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); } - if (lexer_.GetKind() != TokKind::kShape) { - return TokenError(absl::StrCat("expected shape, saw ", + if (lexer_.GetKind() != TokKind::kPrimitiveType) { + return TokenError(absl::StrCat("expected primitive type, saw ", TokKindToString(lexer_.GetKind()))); } - *result = lexer_.GetShapeVal(); + PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal(); lexer_.Lex(); + + // Each element contains a dimension size and a bool indicating whether this + // is a dynamic dimension. 
+ std::vector dimension_sizes; + std::vector dynamic_dimensions; + if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) { + return false; + } + result->set_element_type(primitive_type); + for (int i = 0; i < dimension_sizes.size(); ++i) { + result->add_dimensions(dimension_sizes[i]); + result->set_dynamic_dimension(i, dynamic_dimensions[i]); + } + LayoutUtil::SetToDefaultLayout(result); + + if (lexer_.GetKind() == TokKind::kw_sparse) { + lexer_.Lex(); + const string message = + "expects a brace-bracketed integer for sparse layout"; + int64 max_sparse_elements; + if (!ParseToken(TokKind::kLbrace, message) || + !ParseInt64(&max_sparse_elements) || + !ParseToken(TokKind::kRbrace, message)) { + return false; + } + *result->mutable_layout() = + LayoutUtil::MakeSparseLayout(max_sparse_elements); + return true; + } + + // We need to lookahead to see if a following open brace is the start of a + // layout. The specific problematic case is: + // + // ENTRY %foo (x: f32[42]) -> f32[123] { + // ... + // } + // + // The open brace could either be the start of a computation or the start of a + // layout for the f32[123] shape. We consider it the start of a layout if the + // next token after the open brace is an integer or a colon. + if (lexer_.GetKind() == TokKind::kLbrace && + (lexer_.LookAhead() == TokKind::kInt || + lexer_.LookAhead() == TokKind::kColon)) { + Layout layout; + if (!ParseLayout(&layout)) { + return false; + } + if (layout.minor_to_major_size() != result->rank()) { + return Error( + lexer_.GetLoc(), + StrFormat("Dimensions size is %ld, but minor to major size is %ld.", + result->rank(), layout.minor_to_major_size())); + } + *result->mutable_layout() = layout; + } return true; } bool HloParser::CanBeShape() { - // A non-tuple shape starts with a kShape token; a tuple shape starts with - // '('. - return lexer_.GetKind() == TokKind::kShape || + // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts + // with '('. 
+ return lexer_.GetKind() == TokKind::kPrimitiveType || lexer_.GetKind() == TokKind::kLparen; } @@ -3063,15 +3492,14 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, - std::vector* result) { +bool HloParser::ParseDxD(const string& name, std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); } // 1D if (lexer_.GetKind() == TokKind::kInt) { - tensorflow::int64 number; + int64 number; if (!ParseInt64(&number)) { return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); } @@ -3090,8 +3518,7 @@ bool HloParser::ParseDxD(const string& name, return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad( - std::vector>* pad) { +bool HloParser::ParseWindowPad(std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -3101,7 +3528,7 @@ bool HloParser::ParseWindowPad( } string str = lexer_.GetStrVal(); for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { - std::vector low_high; + std::vector low_high; if (!SplitToInt64s(padding_dim_str, '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -3124,7 +3551,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -3146,7 +3573,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; 
attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -3198,6 +3625,22 @@ bool HloParser::ParseFftType(FftType* result) { return true; } +bool HloParser::ParseTriangularSolveTranspose( + TriangularSolveOptions::Transpose* result) { + VLOG(1) << "ParseTriangularSolveTranspose"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects triangular solve transpose type"); + } + string val = lexer_.GetStrVal(); + if (!TriangularSolveOptions_Transpose_Parse(val, result) || + !TriangularSolveOptions_Transpose_IsValid(*result)) { + return TokenError( + StrFormat("expects triangular solve transpose type but sees: %s", val)); + } + lexer_.Lex(); + return true; +} + bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(1) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { @@ -3249,7 +3692,7 @@ bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { return true; } -bool HloParser::ParseInt64(tensorflow::int64* result) { +bool HloParser::ParseInt64(int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -3261,9 +3704,18 @@ bool HloParser::ParseInt64(tensorflow::int64* result) { bool HloParser::ParseDouble(double* result) { switch (lexer_.GetKind()) { - case TokKind::kDecimal: - *result = lexer_.GetDecimalVal(); + case TokKind::kDecimal: { + double val = lexer_.GetDecimalVal(); + // If GetDecimalVal returns +/-inf, that means that we overflowed + // `double`. 
+ if (std::isinf(val)) { + return TokenError(StrCat("Constant is out of range for double (+/-", + std::numeric_limits::max(), + ") and so is unparsable.")); + } + *result = val; break; + } case TokKind::kInt: *result = static_cast(lexer_.GetInt64Val()); break; @@ -3283,6 +3735,42 @@ bool HloParser::ParseDouble(double* result) { return true; } +bool HloParser::ParseComplex(std::complex* result) { + if (lexer_.GetKind() != TokKind::kLparen) { + return TokenError("expects '(' before complex number"); + } + lexer_.Lex(); + + double real; + LocTy loc = lexer_.GetLoc(); + if (!ParseDouble(&real)) { + return Error(loc, + "expect floating-point value for real part of complex number"); + } + + if (lexer_.GetKind() != TokKind::kComma) { + return TokenError( + absl::StrFormat("expect comma after real part of complex literal")); + } + lexer_.Lex(); + + double imag; + loc = lexer_.GetLoc(); + if (!ParseDouble(&imag)) { + return Error( + loc, + "expect floating-point value for imaginary part of complex number"); + } + + if (lexer_.GetKind() != TokKind::kRparen) { + return TokenError(absl::StrFormat("expect ')' after complex number")); + } + + *result = std::complex(real, imag); + lexer_.Lex(); + return true; +} + bool HloParser::ParseBool(bool* result) { if (lexer_.GetKind() != TokKind::kw_true && lexer_.GetKind() != TokKind::kw_false) { @@ -3332,6 +3820,18 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShapeOnly() { + lexer_.Lex(); + Shape shape; + if (!ParseShape(&shape)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after shape"); + } + return shape; +} + StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; @@ -3344,6 +3844,21 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } +StatusOr> 
HloParser::ParseParameterReplicationOnly() { + lexer_.Lex(); + ParameterReplication parameter_replication; + if (!ParseParameterReplication(¶meter_replication)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after parameter replication"); + } + return std::vector( + parameter_replication.replicated_at_leaf_buffers().begin(), + parameter_replication.replicated_at_leaf_buffers().end()); +} + StatusOr HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; @@ -3459,6 +3974,11 @@ StatusOr ParseSharding(absl::string_view str) { return parser.ParseShardingOnly(); } +StatusOr> ParseParameterReplication(absl::string_view str) { + HloParser parser(str); + return parser.ParseParameterReplicationOnly(); +} + StatusOr ParseWindow(absl::string_view str) { HloParser parser(str); return parser.ParseWindowOnly(); @@ -3475,4 +3995,9 @@ StatusOr ParsePaddingConfig(absl::string_view str) { return parser.ParsePaddingConfigOnly(); } +StatusOr ParseShape(absl::string_view str) { + HloParser parser(str); + return parser.ParseShapeOnly(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index d830fa61438239005875f785f85cf2486123ebc9..a96260b4d75e515a4cb23d315444142cae1b9587 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -44,11 +44,16 @@ Status ParseHloString(absl::string_view str, HloModule* module); // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); -// ParseHloString sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string, -// e.g., "{replicated}". +// Parses sharding from str. str is supposed to contain the body of the +// sharding, i.e. 
just the rhs of the "sharding={...}" attribute string, e.g., +// "{replicated}". StatusOr ParseSharding(absl::string_view str); +// Parses parameter replication from str. str is supposed to contain the body of +// the parameter replication, i.e. just the rhs of the +// "parameter_replication={...}" attribute string, e.g., "{true, false}". +StatusOr> ParseParameterReplication(absl::string_view str); + // Parses the result of window_util::ToString(const Window&). StatusOr ParseWindow(absl::string_view str); @@ -60,6 +65,9 @@ StatusOr ParseConvolutionDimensionNumbers( // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); +// Parses and returns a Shape::ToString-format string. +StatusOr ParseShape(absl::string_view str); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index ab71f011ac9d77d00ddfb41aca7a224d26d416b7..8e3f1e44b9562334130aa565ed447a78899fad53 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -63,6 +63,19 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } +)" +}, +// parameter replication +{ +"ParamReplication", +R"(HloModule param_replication_module + +ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) { + %a = f32[] parameter(0), parameter_replication={true} + %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true} + ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b) +} + )" }, // pred constant @@ -82,7 +95,7 @@ ENTRY %constant_pred () -> pred[] { R"(HloModule module ENTRY %constant_pred_array () -> pred[2,3] { - ROOT %constant = pred[2,3]{1,0} 
constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } }) + ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } }) } )" @@ -128,7 +141,7 @@ ENTRY %ConstantF32Empty.v4 () -> f32[0] { R"(HloModule ConstantF32R4Empty_module ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { - ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) + ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } }) } )" @@ -139,7 +152,7 @@ ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { R"(HloModule Small_3x2x1x1_module ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { - ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) } )" @@ -196,7 +209,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} )) } )" @@ -295,11 +308,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1} ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} %constant = f32[] constant(2.1), sharding={maximal 
device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} } @@ -310,11 +323,11 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { R"(HloModule HostTransferSendRecv_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, is_host_transfer=true + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, is_host_transfer=true + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true } @@ -327,7 +340,7 @@ R"(HloModule GetTupleElement_module ENTRY %GetTupleElement.v4 () -> s32[2,3] { %constant = f32[3]{0} constant({1, 2, 3}) - %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } }) + %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }) %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1) ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0} } @@ -434,7 +447,7 @@ ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f R"(HloModule 
Reverse4DFloatArrayOnDim01_module ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { - %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) + %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} } @@ -446,8 +459,8 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { R"(HloModule Concat2x3With2x5_module ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { - %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) - %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) + %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } }) + %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} } @@ -471,8 +484,8 @@ R"(HloModule R4F32OverlapSmall_module } ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { - %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { 
/*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) - %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) %constant.2 = f32[] constant(0) ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 } @@ -523,7 +536,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { R"(HloModule Slice3x3x3_To_1x3x3_F32_module ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { - %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) + %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} } @@ -547,10 +560,21 @@ ENTRY %SliceR0.v2 () -> s32[] { R"(HloModule 
Transpose_module ENTRY %Transpose.v2 () -> s32[1,2,3] { - %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) + %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } }) ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} } +)" +}, +{ +"TransposeC128", +R"(HloModule TransposeC128_module + +ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] { + %input = c128[1,2,3]{2,1,0} parameter(0) + ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2} +} + )" }, // Dynamic slice @@ -566,12 +590,26 @@ ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) - ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258} } +)" +}, +// Dynamic slice with scalar indices +{ +"DynamicSliceScalarIndices", +R"(HloModule DynamicSlice_module + +ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258]{2,1,0} parameter(0) + %constant = s32[] constant(0) + %start_index = s32[] parameter(1) + ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258} +} + )" }, // Dynamic update slice { "DynamicUpdateSlice", -R"(HloModule DynamicUpdateSlice_module +R"(HloModule DynamicSlice_module ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { %input = s32[1,1,25,1]{3,2,1,0} parameter(0) @@ -580,6 +618,23 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices) } +)" +}, +// Dynamic update slice with scalar indices +{ 
+"DynamicUpdateSliceScalarIndex", +R"(HloModule DynamicUpdateSlice_module + +ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_index.0 = s32[] parameter(2) + %start_index.1 = s32[] parameter(3) + %start_index.2 = s32[] parameter(4) + %start_index.3 = s32[] parameter(5) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3) +} + )" }, // batch norm training @@ -588,7 +643,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -728,7 +783,7 @@ R"(HloModule fusion_module } ENTRY %fusion.v3 () -> f32[3,2,1,1] { - %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { 
/*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) %constant.1 = f32[2]{0} constant({3.14, 4.25}) ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } @@ -740,7 +795,17 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { R"(HloModule sparse_f32 ENTRY %sparse () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) + ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3}) +} + +)" +}, +{ +"SparseC128", +R"(HloModule sparse_c128 + +ENTRY %sparse () -> c128[2,3,4] { + ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)}) } )" @@ -750,7 +815,7 @@ ENTRY %sparse () -> f32[2,3,4] { R"(HloModule sparse_f32_empty ENTRY %sparse_f32_empty () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) + ROOT %foo = f32[2,3,4]sparse{10} constant({}) } )" @@ -760,7 +825,7 @@ ENTRY %sparse_f32_empty () -> f32[2,3,4] { R"(HloModule sparse_f32_r1 ENTRY %sparse_f32_r1 () -> f32[9] { - ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) + ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6}) } )" @@ -852,6 +917,28 @@ ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123 ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} } +)" +}, +// Parse c64 literal +{ +"ParseC64Literal", +R"(HloModule ParseC64Literal + +ENTRY %ParseC64Literal () -> c64[2] { + ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)}) +} + +)" +}, +// Parse c128 literal +{ +"ParseC128Literal", +R"(HloModule ParseC128Literal + +ENTRY %ParseC128Literal () -> c128[2] { + ROOT %c = c128[2]{0} constant({(1, 2), (-inf, 
nan)}) +} + )" }, }); @@ -931,11 +1018,11 @@ ENTRY reduce_entry { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] after-all() - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) @@ -973,9 +1060,15 @@ ENTRY ReducePrecision { "SortKey", R"(HloModule sort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { x = f32[1024]{0} parameter(0) - ROOT sorted = f32[1024]{0} sort(x), dimensions={0} + ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare } )" @@ -985,10 +1078,18 @@ ENTRY Sort { "SortKeyValue", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { keys = f32[1024]{0} parameter(0) values = s32[1024]{0} parameter(1) - ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -998,9 +1099,15 @@ ENTRY Sort { "SortKeyR2", R"(HloModule sort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { x = f32[1024,16]{0,1} parameter(0) - ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0} + ROOT sorted = 
f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare } )" @@ -1010,10 +1117,18 @@ ENTRY Sort { "SortKeyValueR2", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { keys = f32[1024,16]{0,1} parameter(0) values = s32[1024,16]{0,1} parameter(1) - ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -1023,12 +1138,42 @@ ENTRY Sort { "SortManyValues", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.2.lhs = u32[] parameter(4) + p.2.rhs = u32[] parameter(5) + p.3.lhs = f32[] parameter(6) + p.3.rhs = f32[] parameter(7) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { keys = f32[1024,16]{0,1} parameter(0) values.0 = s32[1024,16]{0,1} parameter(1) values.1 = u32[1024,16]{0,1} parameter(2) values.2 = f32[1024,16]{0,1} parameter(3) - ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0} + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare +} + +)" +}, +// Sort (Key) is_stable=true +{ +"SortKeyStable", +R"(HloModule sort + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare } )" @@ -1117,9 +1262,9 @@ ENTRY Gather { )" }, -// cross-replica-sum +// all-reduce { -"CrossReplicaSum", +"AllReduce", R"(HloModule CRS add { @@ -1130,14 
+1275,14 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add + ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add } )" }, -// cross-replica-sum with subgroups +// all-reduce with subgroups { -"CrossReplicaSumWithSubgroups", +"AllReduceWithSubgroups", R"(HloModule CRS_Subgroups add { @@ -1146,16 +1291,16 @@ add { ROOT add = f32[] add(lhs, rhs) } -ENTRY CrossReplicaSumWithSubgroups { +ENTRY AllReduceWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add + ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } )" }, -// cross-replica-sum with all-reduce-id +// all-reduce with all-reduce-id { -"CrossReplicaSumAllReduce", +"AllReduceAllReduce", R"(HloModule CRS add { @@ -1166,8 +1311,8 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - crs.1 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add - ROOT crs.0 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add } )" @@ -1206,6 +1351,17 @@ ENTRY CollectivePermute { ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } +)" +}, +// replica-id +{ +"ReplicaId", +R"(HloModule replica-id + +ENTRY Replica-id { + ROOT replica-id = u32[] replica-id() +} + )" }, // Iota @@ -1235,10 +1391,18 @@ ENTRY Computation { "ScheduledModule", R"(HloModule scheduled_module, is_scheduled=true +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lhs = pred[] 
less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { keys = f32[1024]{0} parameter(0) values = s32[1024]{0} parameter(1) - ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -1266,12 +1430,36 @@ R"(HloModule AddDependency ENTRY AddDependency { p = f32[] parameter(0) neg = f32[] negate(p) - token = token[] after-all(neg) - p_after_token = f32[] add-dependency(p, token) + token0 = token[] after-all(neg) + p_after_token = f32[] add-dependency(p, token0) exp = f32[] exponential(p_after_token) ROOT sum = f32[] add(neg, exp) } +)" +}, + +// A module containing constants equal to the min/max values of various data +// types. +{ +"MinMaxValues", +R"(HloModule MinMaxValues + +ENTRY MinMaxValues { + x.s8 = s8[2]{0} constant({-128, 127}) + x.s16 = s16[2]{0} constant({-32768, 32767}) + x.s32 = s32[2]{0} constant({-2147483648, 2147483647}) + x.u8 = u8[2]{0} constant({0, 255}) + x.u16 = u16[2]{0} constant({0, 65535}) + x.u32 = u32[2]{0} constant({0, 4294967295}) + x.f16 = f16[2]{0} constant({-65504, 65504}) + x.bf16 = bf16[2]{0} constant({-3.38953e+38, 3.38953e+38}) + x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38}) + x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308}) + x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)}) + ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)}) +} + )" }, }); @@ -1298,7 +1486,7 @@ class HloParameterizedParserTest protected: // Expects "ToString(ParseHloString(string)) == string", that is, parses the // string, asserts that it succeeded, stringifies the parsed module, and - // checks that the it equals the original string. + // checks that it equals the original string. 
void ExpectEqual() { const string& original = GetParam().module_string; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -1329,20 +1517,20 @@ TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); } TEST_P(HloParserTestShort, Run) { ExpectEqual(); } TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); } -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong, - ::testing::ValuesIn(CreateTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, - HloParserTestLongProto, - ::testing::ValuesIn(CreateTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort, - ::testing::ValuesIn(CreateShortTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, - HloParserTestShortProto, - ::testing::ValuesIn(CreateShortTestCases()), - TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, + HloParserTestLongProto, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, + HloParserTestShortProto, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); class HloParserTest : public ::testing::Test { protected: @@ -1419,7 +1607,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) } @@ -1462,7 +1650,7 @@ TEST_F(HloParserTest, 
LiteralDimensionsMismatch_2) { const string original = R"(HloModule some_2x3_module ENTRY %some_2x3 () -> f32[2,3] { - ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) + ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6}) } )"; @@ -1476,7 +1664,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { const string original = R"(HloModule some_2x3x2_module ENTRY %some_2x3x2 () -> f32[2,3,2] { - ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) + ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) } )"; @@ -1501,6 +1689,37 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { "is out of range for literal's primitive type F16"); } +TEST_F(HloParserTest, ConstantBf16NoOverflow) { + // 65505 is in range for bf16. + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = bf16[] constant(-65505) + })"; + EXPECT_EQ(Status::OK(), ParseHloString(original).status()); +} + +TEST_F(HloParserTest, ConstantBf16Overflow) { + // 1e100 is out of range for bf16. 
+ const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = bf16[] constant(1e100) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "out of range"); +} + +TEST_F(HloParserTest, ConstantF16OverflowInSparseArray) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[0]: 0, [1]: -65505}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "is out of range for literal's primitive type F16"); +} + TEST_F(HloParserTest, ConstantUnsignedUnderflow) { const string original = R"( HloModule ConstantUnsignedUnderflow_module @@ -1535,6 +1754,46 @@ TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { EXPECT_NE(Status::OK(), result.status()); } +TEST_F(HloParserTest, ConstantC64Overflow) { + const string original = R"( + HloModule test_module + ENTRY test () -> c64[] { + ROOT c = c64[] constant((1e100, 0)) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantC64Underflow) { + const string original = R"( + HloModule test_module + ENTRY test () -> c64[] { + ROOT c = c64[] constant((0, -1e100)) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantF64Overflow) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f64[] constant(1.8e308) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantF64Underflow) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f64[] constant(-1.8e308) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module @@ -1550,6 +1809,19 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { // printed as "300". 
} +TEST_F(HloParserTest, ShortConstant) { + const string original = R"(HloModule ShortCOnstant_module + +ENTRY %ShortConstant.v4 () -> f32[67,89] { + ROOT %constant.1 = f32[67,89]{1,0} constant({...}) +} + +)"; + auto result = ParseHloString(original); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); +} + TEST_F(HloParserTest, AttibutesAnyOrder) { const string original = R"(HloModule any_order_module @@ -1594,11 +1866,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1611,11 +1883,11 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0) %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1628,11 +1900,11 @@ 
TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1940,8 +2212,8 @@ TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) { TEST_F(HloParserTest, NontupleInfeed) { const string original = R"(HloModule nontuple_infeed: ENTRY nontuple_infeed { - token = token[] after-all() - ROOT infeed = pred[] infeed(token) + token0 = token[] after-all() + ROOT infeed = pred[] infeed(token0) })"; ExpectHasSubstr(ParseHloString(original).status().error_message(), "infeed must have a non-empty tuple shape"); @@ -2239,7 +2511,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { %p = f32[2,2] parameter(0) - %constant.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1) } )"; @@ -2249,7 +2521,218 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { " with the shape of the operand instruction f32[2,2]{1,0}."); } -// custom call incompatible shape. 
+TEST_F(HloParserTest, OutOfRangeSparseIndex) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[100]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, NegativeSparseIndex) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({-1: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, SparseIndexWithRankTooLarge) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[0, 0]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, SparseIndexWithRankTooSmall) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5, 5]sparse{10} constant({[0]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, ParseShapeStringR2F32) { + string shape_string = "f32[123,456]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) { + string shape_string = "(f32[1572864],s8[5120,1024])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), + ShapeUtil::MakeShape(S8, {5120, 1024})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringNestedTuple) 
{ + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), + ShapeUtil::MakeShape(F32, {3}), + }); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithLayout) { + string shape_string = "f32[123,456]{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) { + // One tile. + string shape_string = "f32[123,456]{0,1:T(2,128)}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {Tile({2, 128})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Tile with negative dimension size for combining dimensions. + shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = + ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2}, + {Tile({2, Tile::kCombineDimension, 128})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Two tiles. 
+ shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout( + BF16, {123, 456, 789}, {2, 1, 0}, + {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Tile with element size in bits. + shape_string = "pred[123,456]{1,0:T(2,128)E(1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, + {Tile({2, 128})}, 1); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Element size in bits without tile. + shape_string = "pred[123,456]{1,0:E(1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Wrong minor_to_major. 
+ shape_string = "f32[123,456,789]{1:T(2, * , 128)}"; + auto result = ParseShape(shape_string); + ExpectHasSubstr(result.status().error_message(), + "Dimensions size is 3, but minor to major size is 1."); +} + +TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) { + string shape_string = "f32[123,456]sparse{10}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseInvalidShapeString) { + string shape_strings[] = { + "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", + "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", + }; + for (const string& shape_string : shape_strings) { + StatusOr result = ParseShape(shape_string); + ASSERT_FALSE(result.ok()) << "shape: " << shape_string; + } +} + +TEST_F(HloParserTest, ParseDynamicArray) { + string shape_string = "f32[123,<=456]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << 
ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseDynamicTuple) { + string shape_string = "(f32[42], u32[<=123,<=456])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {42}), + ShapeUtil::MakeShape(U32, {123, 456}, {true, true})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, NegativeParameterNumber) { + const string hlo_string = "par0 = f32[3,5] parameter(-1)"; + auto result = ParseHloString(hlo_string); + ASSERT_FALSE(result.status().ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("parameter number must be >= 0")); +} + +TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) { + const string hlo_string = + "par0 = (f32[3,5], f32[]) parameter(0), " + "parameter_replication={true,false,true}"; + auto result = ParseHloString(hlo_string); + ASSERT_FALSE(result.status().ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("parameter has 2 leaf buffers, but " + "parameter_replication has 3 elements")); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index 791b1a97b0b82edf19ff1588fd8d5d996ac0fef4..35dc9c0029f9871334cb500c6b71f0c86ab136d7 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -39,9 +40,36 @@ class HloPassFix : public Pass { int64 iteration_count = 0; int64 limit = std::max(static_cast(1000), module->instruction_count()); + VLOG(3) << "Running HloPassFix."; while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; + ++iteration_count; + if (iteration_count == limit) { + LOG(ERROR) + << "Unexpectedly high number of iterations in HLO passes (" + << iteration_count + << ")\nIf compilation hangs here, please file a bug with XLA."; + } + } + return changed; + } + + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + bool changed_this_iteration = true; + int64 iteration_count = 0; + int64 limit = 1000; + for (const HloModule* module : module_group->modules()) { + limit = std::max(limit, module->instruction_count()); + } + VLOG(3) << "Running HloPassFix."; + while (changed_this_iteration) { + TF_ASSIGN_OR_RETURN(changed_this_iteration, + Pass::RunOnModuleGroup(module_group)); + changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; ++iteration_count; if (iteration_count == limit) { LOG(ERROR) diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 51177f24f5ee702be96fc8b4530ed38a5798109f..ae8c08cf1d16ad6738962f3be7c1b5512110b1d1 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -77,6 +77,11 @@ std::vector HloPassPipeline::GetEnabledPasses( auto repeated_field = debug_options.xla_disable_hlo_passes(); 
absl::flat_hash_set disabled_pass_names(repeated_field.begin(), repeated_field.end()); + if (debug_options.xla_disable_all_hlo_passes()) { + VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes."; + return {}; + } + if (!disabled_pass_names.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " << absl::StrJoin(disabled_pass_names, ", "); @@ -84,7 +89,7 @@ std::vector HloPassPipeline::GetEnabledPasses( std::vector enabled_passes; for (auto& pass : passes_) { - if (disabled_pass_names.count(string(pass->name())) == 0) { + if (!disabled_pass_names.contains(pass->name())) { enabled_passes.push_back(pass.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index 5eb707a957e49d86cdb2f72b72ce750bf29b8fd2..9cc202aa9f5fe5a20a9da05251ea811137ccaadb 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -34,11 +35,10 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, for (const HloComputationInfo& computation_info : hlo_profile_printer_data.computation_infos()) { const auto& instruction_infos = computation_info.instruction_infos(); - bool any_instruction_profiled = - std::any_of(instruction_infos.begin(), instruction_infos.end(), - [&](const HloInstructionInfo& instruction_info) { - return counters[instruction_info.profile_index()] != 0; - }); + bool any_instruction_profiled = absl::c_any_of( + instruction_infos, [&](const HloInstructionInfo& instruction_info) { + return counters[instruction_info.profile_index()] != 0; + }); if (!any_instruction_profiled) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 981d06ce101644ecce587c4bd2f7a12c8edf6548..3a9ee57e5551ae5b608f02d9f8bd0428ff16db13 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -39,6 +39,7 @@ HloProto MakeHloProto(const HloModule& module) { StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(4) << proto.ShortDebugString(); TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(proto, module_config)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 4aa8067752481ffab29e1a573ffa49d4aa046f1f..b7f507b1184dbe021effc1102a68040286480ed2 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -49,7 +49,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper( absl::Span inputs, const 
HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. - if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { + if (!absl::c_linear_search(inputs, instruction)) { bit_vector->SetToZero(); } bit_vector->Set(GetIndex(instruction)); @@ -77,28 +77,51 @@ std::unique_ptr HloReachabilityMap::Build( const HloComputation* computation) { const auto& all = computation->MakeInstructionPostOrder(); auto result = absl::make_unique(all); - auto channel_dependency_map = computation->ComputeChannelDependencies(); + auto channel_group = computation->ComputeChannelDependencies(); - std::vector inputs; for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); + std::vector inputs; + const auto add_input = [&channel_group, &inputs](HloInstruction* input) { + inputs.push_back(input); + if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) { + auto it = channel_group.find(*input->all_reduce_id()); + if (it != channel_group.end()) { + inputs.insert(inputs.end(), it->second.begin(), it->second.end()); + } + } + }; + + const auto add_dependencies = [&add_input](const HloInstruction* hlo) { + for (HloInstruction* operand : hlo->operands()) { + add_input(operand); + } + for (HloInstruction* predecessor : hlo->control_predecessors()) { + add_input(predecessor); + } + }; + + add_dependencies(hlo); switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); + auto it = channel_group.find(hlo->channel_id()); + if (it != channel_group.end()) { + for (HloInstruction* channel : it->second) { + if (channel->opcode() == HloOpcode::kSend) { + add_input(channel); + } + } } break; } - case 
HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { auto all_reduce_id = hlo->all_reduce_id(); if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); + auto it = channel_group.find(all_reduce_id.value()); + if (it != channel_group.end()) { + for (HloInstruction* all_reduce : it->second) { + add_dependencies(all_reduce); + } } } break; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 48add75523f02005c70bc6baf69a6b7d5aa4f7ef..a175e4643de2ac6ce07ac00da914d7ab7acca541 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -57,13 +57,22 @@ using ::tensorflow::strings::HumanReadableNumBytes; // Returns true if the given instruction is rematerializable. bool IsRematerializable(const HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kCopy) { + if (LayoutUtil::Equal(instruction->shape().layout(), + instruction->operand(0)->shape().layout())) { + // Don't rematerialize copies added by copy insertion (layout doesn't + // change). + return false; + } + } + // Don't rematerialize instructions with side effects or instructions which // cannot be cloned safely. 
switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kConstant: case HloOpcode::kConditional: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kCustomCall: case HloOpcode::kParameter: case HloOpcode::kWhile: @@ -179,7 +188,8 @@ class InstructionList { Item* CreateItem(HloInstruction* inst) { Item* item = new Item; item->instruction = inst; - CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice"; + CHECK(item_map_.insert({inst, item}).second) + << "inserting inst twice " << inst->name(); return item; } @@ -235,8 +245,7 @@ class InstructionList { } // Now scan forwards until we find one of the before_instructions. - while (std::find(before_instructions.begin(), before_instructions.end(), - min_position_item) == before_instructions.end()) { + while (!absl::c_linear_search(before_instructions, min_position_item)) { min_position_item = min_position_item->next; } return InsertBefore(to_insert, min_position_item); @@ -302,7 +311,7 @@ ItemList GetUsers(const InstructionList& instruction_list, // A buffer may be used by the instruction via more than one alias. For // example, a buffer which appears in more than one element of a tuple. Item* user_item = instruction_list.GetItem(user); - if (std::find(users.begin(), users.end(), user_item) == users.end()) { + if (!absl::c_linear_search(users, user_item)) { users.push_back(user_item); } } @@ -418,11 +427,12 @@ class MemoryUsageTracker { // the given uses. 
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, ItemList&& rematerialized_uses) { - CHECK(original_buffer.defining_instruction->placed); - CHECK(!original_buffer.has_indirect_uses); - CHECK(!original_buffer.live_out); + CHECK(original_buffer.defining_instruction->placed) + << original_buffer.defining_instruction->instruction->name(); + CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString(); + CHECK(!original_buffer.live_out) << original_buffer.ToString(); for (Item* use : rematerialized_uses) { - CHECK(!use->placed); + CHECK(!use->placed) << use->instruction->name(); } return NewBuffer(remat_item, original_buffer.size, std::move(rematerialized_uses), /*live_out=*/false, @@ -456,8 +466,7 @@ class MemoryUsageTracker { return false; } const BufferIdList& in_progress_uses = in_progress_item_->buffers_used; - return std::find(in_progress_uses.begin(), in_progress_uses.end(), - buffer_id) != in_progress_uses.end(); + return absl::c_linear_search(in_progress_uses, buffer_id); } // Returns whether the given instruction is live at the current program @@ -535,8 +544,7 @@ MemoryUsageTracker::MemoryUsageTracker( bool unused; for (Item* user_item : GetUsers(instruction_list_, logical_buffer, points_to_analysis, &unused)) { - if (std::find(buffer->users.begin(), buffer->users.end(), - user_item) == buffer->users.end()) { + if (!absl::c_linear_search(buffer->users, user_item)) { buffer->users.push_back(user_item); buffer->unfinished_user_count++; user_item->buffers_used.push_back(buffer->id); @@ -677,8 +685,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, << ", remat_instruction = " << remat_item->instruction->name(); TF_RET_CHECK(in_progress_item_ != nullptr); - TF_RET_CHECK(original_item->placed); - TF_RET_CHECK(!remat_item->placed); + TF_RET_CHECK(original_item->placed) << original_item->instruction->name(); + TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name(); // Construct the 
list of buffers used and defined by the rematerialization. remat_item->buffers_used = original_item->buffers_used; @@ -707,7 +715,7 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, ItemList unplaced_users; for (Item* user : old_buffer.users) { if (user->placed) { - CHECK(IsFinished(user)); + CHECK(IsFinished(user)) << user->instruction->name(); placed_users.push_back(user); } else { unplaced_users.push_back(user); @@ -784,8 +792,7 @@ bool MemoryUsageTracker::Check() const { for (const Buffer& buffer : buffers_) { if (buffer.defining_instruction->instruction == instruction) { - CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), - buffer.id) != defined_buffers.end()) + CHECK(absl::c_linear_search(defined_buffers, buffer.id)) << "Instruction " << instruction->name() << " defined buffers is missing: " << buffer.ToString(); } @@ -808,8 +815,7 @@ bool MemoryUsageTracker::Check() const { int64 unfinished_uses = 0; for (Item* user : buffer.users) { const BufferIdList& used_buffers = user->buffers_used; - CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != - used_buffers.end()) + CHECK(absl::c_linear_search(used_buffers, buffer.id)) << "Instruction " << user->instruction->name() << " used buffers is missing " << buffer.ToString(); if (!IsFinished(user)) { @@ -836,10 +842,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // If none of the users of 'instruction' have been placed in the sequence (as // tracked by memory_tracker), then rematerialization of 'instruction' is a // zero-cost move of 'instruction' in the sequence. 
- if (!std::any_of(instruction->users().begin(), instruction->users().end(), - [&memory_tracker](const HloInstruction* inst) { - return memory_tracker.IsPlaced(inst); - })) { + if (!absl::c_any_of(instruction->users(), + [&memory_tracker](const HloInstruction* inst) { + return memory_tracker.IsPlaced(inst); + })) { return 0; } @@ -1094,7 +1100,7 @@ StatusOr HloRematerialization::RematerializeComputation( Item* successor_item = instruction_list.GetItem(successor); // Assert to make sure we never remat an operation with control // successor already placed. - CHECK(!successor_item->placed); + CHECK(!successor_item->placed) << successor_item->instruction->name(); place_before.push_back(successor_item); } instruction_list.InsertBeforeInstructions(remat_item, place_before); @@ -1164,7 +1170,7 @@ StatusOr HloRematerialization::RematerializeComputation( // Verify some invariants on the memory tracker. CHECK_EQ(memory_tracker.memory_usage(), 0); for (auto* instruction : computation->instructions()) { - CHECK(memory_tracker.IsPlaced(instruction)); + CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name(); } VLOG(1) << "In computation " << computation->name() << " rematerialized " diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 22c3c40a93a1ddcd36659483fcc79fede32dd2c3..102a360ad8116d8781baf9cb7627a920f4a687c4 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -499,6 +499,52 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_THAT(add_4->operand(0), op::Broadcast(param)); } +TEST_F(HloRematerializationTest, CopyNotRematerialized) { + // Test that copies are not rematerialized. 
+ auto module = CreateNewVerifiedModule(); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kCopy, param)); + + auto negate_a_1 = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy)); + + auto negate_a_2 = builder.AddInstruction(HloInstruction::CreateUnary( + vec1024_shape_, HloOpcode::kNegate, negate_a_1)); + + auto negate_b_1 = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy)); + + auto negate_b_2 = builder.AddInstruction(HloInstruction::CreateUnary( + vec1024_shape_, HloOpcode::kNegate, negate_b_1)); + + builder.AddInstruction(HloInstruction::CreateTuple({negate_a_2, negate_b_2})); + + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/1 * 1024, module.get())); + + auto count_copies = [](const HloComputation* computation) { + int64 copy_count = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + copy_count++; + } + } + return copy_count; + }; + EXPECT_TRUE(changed); + + EXPECT_EQ(count_copies(entry_computation), 1); +} + class IndirectUseTest : public HloRematerializationTest, public ::testing::WithParamInterface {}; @@ -588,8 +634,8 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { } } -INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest, - ::testing::Values(true, false)); +INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest, + ::testing::Values(true, false)); } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 
5a9b820a9d7f58695383b21c9e2126cf98970c83..5a5401e351384867016a3a9addfd43d57091848c 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -168,6 +168,35 @@ StatusOr HloRunner::Execute(std::unique_ptr module, /*profile=*/profile); } +StatusOr HloRunner::Execute( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*executable=*/executable.get(), + /*arguments=*/argument_buffers, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); +} + +StatusOr HloRunner::Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); + } + return Execute( + /*module=*/std::move(executable), + /*arguments=*/argument_pointers, + /*profile=*/profile); +} + StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, const absl::Span arguments, bool run_hlo_passes, @@ -206,7 +235,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { // Get service run options. 
@@ -225,7 +254,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { std::vector argument_pointers; @@ -383,9 +412,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( if (device_assignment != nullptr) { run_options.set_device_assignment(device_assignment); } - return ServiceExecutableRunOptions( - run_options, backend().StreamBorrower(), - /*xla_intra_op_thread_pool=*/backend().eigen_intra_op_thread_pool()); + return ServiceExecutableRunOptions(run_options, backend().StreamBorrower()); } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index bb792cf8c9825ff67ca33bbcf2c3c32b1a0ecb85..098989cd4c78fb5ad57cd6700fbf99c50064f225 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -60,7 +60,7 @@ class HloRunner { // The number of times the infeed literal should be fed to the HLO module. // For a clean exit, this should match the iterations-per-loop parameter // used when generating the HLO module proto (that is usually the main - // while bounary counter). A value higher then iterations-per-loop would + // while boundary counter). A value higher then iterations-per-loop would // lead to infeed threads feeding to a gone computation, while a lower // value would trigger a stuck ExecuteReplicated() call (the computation // will be trying to infeed data which will never come). 
@@ -124,6 +124,14 @@ class HloRunner { bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + StatusOr Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + // As Execute(), but accepts and returns device buffers instead of host // buffers. StatusOr ExecuteWithDeviceBuffers( @@ -136,13 +144,16 @@ class HloRunner { const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + // In the following two calls, "executable" is not a unique_ptr to allow + // reuse of the Executable. This call may update the profile information in + // *executable. StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 8f6eb974c5179b420c8f961393ca923e0a3b3530..e75373501cffac6a736be89e9f6139b6ff2cdbc1 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -140,7 +140,7 @@ Status HloSchedule::UpdateComputationSchedule( std::queue worklist; for (HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { + if (!ids_in_schedule.contains(instruction->unique_id())) { // This is a newly added instruction which is not in the schedule. 
if (instruction->operands().empty()) { worklist.push(instruction); @@ -204,7 +204,7 @@ Status HloSchedule::Update() { std::vector nonfusion_computations = module_->MakeNonfusionComputations(); for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + TF_RET_CHECK(sequences_.contains(computation->unique_id())) << "Computation " << computation->name() << " not in HloSchedule."; } if (sequences_.size() > nonfusion_computations.size()) { @@ -215,7 +215,7 @@ Status HloSchedule::Update() { nonfusion_computations_ids.insert(computation->unique_id()); } for (auto it = sequences_.begin(); it != sequences_.end();) { - if (nonfusion_computations_ids.count(it->first) == 0) { + if (!nonfusion_computations_ids.contains(it->first)) { sequences_.erase(it++); } else { ++it; @@ -244,7 +244,7 @@ Status HloSchedule::Verify() const { << "Schedule has " << sequences_.size() << " sequences, but module has " << nonfusion_computations.size() << " non-fusion computations"; for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + TF_RET_CHECK(sequences_.contains(computation->unique_id())) << "Computation " << computation->name() << " missing from HLO schedule."; } @@ -268,7 +268,7 @@ Status HloSchedule::Verify() const { << instruction_position.size() << " instructions, expected " << computation->instruction_count(); for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) + TF_RET_CHECK(instruction_position.contains(instruction)) << "Instruction " << instruction->name() << " is not in schedule"; } diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 486ddbf499de80c634bc497158cd79ca066cc866..a5f54ae2c33259d080631061dff9ae40b41495dc 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ 
b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -110,7 +110,7 @@ class HloSchedule { // Returns true if the schedule has a sequence for the given computation. bool is_computation_scheduled(const HloComputation* computation) const { - return sequences_.count(computation->unique_id()) == 1; + return sequences_.contains(computation->unique_id()); } // Updates the schedule such that it is (again) a valid schedule for the diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 70a860c356ca2fb1c4c973ea3d96c50fabc2c7c2..f1d7e60f2b5a68408f6d428a0ec47fba3c9c4f12 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/overflow_util.h" @@ -30,7 +31,7 @@ HloSharding HloSharding::AssignDevice(int64 device_id) { } HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { - CHECK_EQ(1, ShapeUtil::Rank(input_shape)); + CHECK_EQ(1, input_shape.rank()); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); Array assignment(dimensions); @@ -57,7 +58,7 @@ HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { HloSharding HloSharding::Tuple(const Shape& tuple_shape, absl::Span shardings) { - CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape); for (auto& sharding : shardings) { CHECK(!sharding.IsTuple()) << sharding.ToString(); } @@ -70,7 +71,7 @@ HloSharding HloSharding::Tuple(const Shape& tuple_shape, HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { - CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + 
CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); int64 leaf_count = RequiredLeaves(tuple_shape); std::vector flattened_list; @@ -80,7 +81,7 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, HloSharding HloSharding::Single(const Shape& shape, const HloSharding& sharding) { - return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding; + return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding; } string HloSharding::ToString() const { @@ -95,24 +96,23 @@ string HloSharding::ToString() const { if (replicated_) { return "{replicated}"; - } else if (maximal_) { + } + if (maximal_) { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); - } else { - return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), - "]", StrJoin(tile_assignment_, ","), "}"); } + return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]", + StrJoin(tile_assignment_, ","), "}"); } bool HloSharding::UsesDevice(int64 device) const { if (IsTuple()) { - return std::any_of( - tuple_elements_.begin(), tuple_elements_.end(), - [&](const HloSharding& s) { return s.UsesDevice(device); }); + return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) { + return s.UsesDevice(device); + }); } const auto& devices = tile_assignment_; - return replicated_ || - std::find(devices.begin(), devices.end(), device) != devices.end(); + return replicated_ || absl::c_linear_search(devices, device); } std::map HloSharding::UsedDevices(int64* count) const { @@ -269,7 +269,7 @@ int64 HloSharding::GetUniqueDevice() const { } Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { return tensorflow::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); } @@ -305,7 +305,7 @@ Status HloSharding::Validate(const Shape& shape, int64 
num_devices) const { Status HloSharding::ValidateNonTuple(const Shape& shape, int64 num_devices) const { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return tensorflow::errors::InvalidArgument( StrCat("Validation shape is a tuple but sharding is not.")); } @@ -316,7 +316,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, // All tile assignments must be less than the number of available cores and // unique. Status status = Status::OK(); - std::set seen_cores; + absl::flat_hash_set seen_cores; tile_assignment_.Each( [&](absl::Span indices, int32 core) { // Don't overwrite a bad status, so we report the first error. @@ -324,12 +324,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, if (core >= num_devices) { status = tensorflow::errors::InvalidArgument(StrCat( "core ", core, " > ", num_devices, " in tile assignment")); - } else if (seen_cores.count(core) != 0) { + } else if (seen_cores.contains(core)) { status = tensorflow::errors::InvalidArgument( StrCat("core ", core, " is not unique in tile assignment")); } + seen_cores.insert(core); } - seen_cores.insert(core); }); if (!status.ok()) { return status; @@ -340,14 +340,14 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, } // The tile assignment tensor must have the same rank as the input. - if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { + if (shape.rank() != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( "Number of tile assignment dimensions is different to the input rank. " "sharding=", ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } - // The correct constructor have to be used to create tile maximal shardings. + // The correct constructor has to be used to create tile maximal shardings. if (tile_assignment_.num_elements() == 1) { return tensorflow::errors::InvalidArgument( "Tile assignment only contains a single device. 
If a replicated " @@ -437,8 +437,8 @@ Shape HloSharding::TileShape(const Shape& shape) const { } Shape result_shape = shape; for (int64 i = 0; i < shape.dimensions_size(); ++i) { - (*result_shape.mutable_dimensions())[i] = - CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i)); + result_shape.set_dimensions( + i, CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i))); } return result_shape; } @@ -455,7 +455,7 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, } sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx}); } - if (ShapeUtil::IsTuple(*sub_shape)) { + if (sub_shape->IsTuple()) { auto begin_it = tuple_elements_.begin() + sharding_index; std::vector sub_shardings( begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape)); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 9775505f8608ced3e33abe376f4922cc6a972726..dd57ea83f1cb33aa052facb607bc040d2e708633 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -101,8 +101,8 @@ class HloSharding { if (!IsTuple()) { return replicated_; } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsReplicated(); }); + return absl::c_all_of( + tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); }); } // Returns true if the tile size is the same as the input size. @@ -110,14 +110,15 @@ class HloSharding { if (!IsTuple()) { return maximal_; } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsTileMaximal(); }); + return absl::c_all_of(tuple_elements_, [](const HloSharding& s) { + return s.IsTileMaximal(); + }); } // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; - // Retrieves an histogram of the devices used by the sharding. 
The returned + // Retrieves a histogram of the devices used by the sharding. The returned // map has the device number as key, and the occurrence count as value. // If a sharding does not have a device, it will not be incuded in the // histogram. The count argument, if not nullptr, will receive the total @@ -259,6 +260,19 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; + // This field is only used if replicated_ is false. If maximal_ is true, then + // the field contains a rank 1 array with a single element, which is the + // device the HLO is assigned to. If maximal_ is false, the field contains an + // array with the same rank as the corresponding HLO. The dimension sizes of + // the array describe the number of ways the HLO is partitioned along each + // dimension. The values of the array specify which device each tile of + // the HLO is assigned to. The index of each value determines which tile it + // takes. + // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is + // "{devices=[2,1,2]2,3,5,7}"), means that dimension 1 is split two way and + // dimension 3 is split 2 way. Core 5, whose index is [2,1,1] will take the + // tile that contains the 2nd half of dimension 1 and the 1st half of + // dimension 3. Array tile_assignment_; // Only non-empty when tuple_ is true. If a tuple is empty then one entry is // present for the root. 
This is a flattened list of all the leaf shardings in diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index f5061304456e04ab40448861343ef201c9450dcf..094d98bc6e54028557f6d38cd165bf34e1fb8c46 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -99,7 +99,7 @@ std::vector LocatePassThroughDomainLinks( << "Instruction is not a kDomain: " << instruction->ToString(); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(user) != 0) { + domain.exit_domains.contains(user)) { pass_through.emplace_back(user, instruction); VLOG(2) << "Found passthrough domain link:"; VLOG(2) << " " << user->ToString(); @@ -234,7 +234,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, if (instruction->users().empty()) { // No sharding from users, use domain_sharding, after checking // compatibility. - TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) && + TF_RET_CHECK(instruction->shape().IsTuple() && ShapeUtil::GetLeafCount(instruction->shape()) == domain_sharding.tuple_elements().size()); instruction->set_sharding(domain_sharding); @@ -253,7 +253,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(user) > 0) { + domain.exit_domains.contains(user)) { // If a user is a domain and it is registered in the domain exits, then // the instruction sharding is taken directly from the domain, and no // further users need to be visited. 
@@ -266,7 +266,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, AssignmentKind sub_assigned = AssignmentKind::kUnassigned; TF_ASSIGN_OR_RETURN(ShapeTree user_sharding_tree, GetShardingTreeFromUser(*instruction, *user)); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { // For tuple-shaped instructions collect individual tuple subshardings // from the uses, and then combine them into the tuple sharding. // If the user is a GTE its sharding concerns only the subtree of @@ -298,7 +298,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, } if (assigned == AssignmentKind::kAssigned) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { instruction->set_sharding(HloSharding::Tuple(sharding_tree)); } else { TF_RET_CHECK(sharding_tree.leaf_count() == 1); @@ -361,7 +361,7 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, // kUnassignedDevice. Indeed in case of doubt it is better to leave the // entire tuple unassigned, and let the device placer decide for it. if (instruction->sharding().UsesDevice(kUnassignedDevice)) { - TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())) + TF_RET_CHECK(instruction->shape().IsTuple()) << "Only tuples can have kUnassignedDevice sub shardings"; instruction->clear_sharding(); } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 80634677e78e4a35dcb9bf7de018a88122c3c030..9e234e025586ff14f99da73afc5610c627303a36 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -84,7 +84,7 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should fail because of more devices used then `num_device`. + // Test should fail because of more devices used than `num_device`. 
HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), /*num_devices=*/2)); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc deleted file mode 100644 index 487653344976a10e18ba667085525ba1ecbb8612..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ /dev/null @@ -1,243 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -LIcensed under the Apache License, Version 2.0 (the "License"); -You may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -using absl::StrAppend; -using absl::StrCat; -using tensorflow::GraphDef; -using tensorflow::NodeDef; -using tensorflow::TensorShapeProto; - -string GetOpDefName(const HloInstruction* instruction) { - string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok - name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); - - if (instruction->opcode() == HloOpcode::kFusion) { - string fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, absl::string_view(fusion_name).substr(1)); - } - return name; -} - -TensorShapeProto GetTensorShape(const HloInstruction* instruction) { - TensorShapeProto tensor_shape; - const Shape& shape = instruction->shape(); - for (auto dim : shape.dimensions()) { - tensor_shape.add_dim()->set_size(dim); - } - return tensor_shape; -} - -string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } - -void CleanNodeName(string* name) { - name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); - const string chars_to_replace = "<>[]"; - auto pred = [&](char c) { - return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != - chars_to_replace.end(); - }; - std::replace_if(name->begin(), name->end(), pred, '_'); -} - -} // namespace - 
-HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options) - : debug_options_(debug_options) {} - -Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { - VLOG(2) << "Adding computation " << computation.name(); - for (auto embedded : computation.MakeEmbeddedComputationsList()) { - for (auto* instruction : embedded->instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction)); - } - } - for (auto* instruction : computation.instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction)); - } - return Status::OK(); -} - -const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; } - -const string& HloTfGraphBuilder::GetNodeNameForInstruction( - const HloInstruction* instruction) { - if (ContainsKey(instruction_to_node_name_, instruction)) { - return instruction_to_node_name_[instruction]; - } - auto append = [](string* str, const string& other) { - if (str->empty()) { - *str = other; - } else if (!other.empty()) { - StrAppend(str, "/", other); - } - }; - string node_name; - if (debug_options_.xla_hlo_tfgraph_device_scopes()) { - auto device = instruction->sharding_unique_device(); - if (device) { - node_name = StrCat("dev", *device); - } - } - // If an instruction is fused, put it in the subgraph of the fusion; - // otherwise, put it in the computation subgraph. - const HloComputation* computation = instruction->parent(); - if (computation->IsFusionComputation()) { - append(&node_name, - GetNodeNameForInstruction(computation->FusionInstruction())); - } else { - append(&node_name, computation->name()); - if (!instruction->metadata().op_name().empty()) { - // Always make computations contain TF ops but not the other way around. 
- append(&node_name, instruction->metadata().op_name()); - } - } - string instruction_name = instruction->name(); - if (instruction->opcode() == HloOpcode::kParameter) { - StrAppend(&instruction_name, ".", instruction->parameter_number()); - } - append(&node_name, instruction_name); - CleanNodeName(&node_name); - auto ret = - instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); - CHECK(ret.second); - return ret.first->second; -} - -void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, - NodeDef* node_def) const { - auto& attrs = *node_def->mutable_attr(); - - // Set the number of arguments for instructions that have variadic operands. - if (HloOpcodeIsVariadic(instruction->opcode())) { - tensorflow::AttrValue attr_value; - attr_value.set_i(instruction->operands().size()); - attrs["arg_num"] = attr_value; - } - - // Set the node type. - attrs["type"].set_s( - xla::PrimitiveType_Name(instruction->shape().element_type())); - - // Set the framework op (e.g. Tensorflow op) that generated this XLA op. - attrs["tf_op_type"].set_s(instruction->metadata().op_type()); - attrs["tf_op_name"].set_s(instruction->metadata().op_name()); - - // Set the shape of the output tensor. "_output_shapes" is a special attribute - // name used by Tensorboard for shapes of output tensors. - tensorflow::AttrValue shapes; - *shapes.mutable_list()->add_shape() = GetTensorShape(instruction); - attrs["_output_shapes"] = shapes; - - // Set the layout. - if (LayoutUtil::HasLayout(instruction->shape())) { - string layout_string; - if (ShapeUtil::IsTuple(instruction->shape())) { - // For tuples, emit the full shape because the layout of a tuple is not - // represented in a single Layout field. 
- layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); - } else { - layout_string = StrCat( - "{", - absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","), - "}"); - } - attrs["layout"].set_s(layout_string); - } - - // Set op-specific attributes. - switch (instruction->opcode()) { - case HloOpcode::kConcatenate: - case HloOpcode::kBroadcast: - case HloOpcode::kReduce: - case HloOpcode::kReverse: - case HloOpcode::kTranspose: - for (auto dim : instruction->dimensions()) { - attrs["dims"].mutable_list()->add_i(dim); - } - break; - case HloOpcode::kGetTupleElement: - attrs["index"].set_i(instruction->tuple_index()); - break; - case HloOpcode::kRng: - attrs["dist"].set_s( - RandomDistribution_Name(instruction->random_distribution())); - break; - case HloOpcode::kConstant: - if (ShapeUtil::IsScalar(instruction->shape())) { - attrs["value"].set_s(instruction->literal().GetAsString({})); - } - break; - case HloOpcode::kCustomCall: - attrs["custom_call_target"].set_s(instruction->custom_call_target()); - break; - case HloOpcode::kSend: - case HloOpcode::kRecv: - attrs["channel_id"].set_i(instruction->channel_id()); - break; - default: - break; - } -} - -Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { - if (!visited_instructions_.insert(instruction).second) { - // Skip instructions that have already been added. - return Status::OK(); - } - - NodeDef* node_def = graph_def_.add_node(); - node_def->set_name(GetNodeNameForInstruction(instruction)); - node_def->set_op(GetOpDefName(instruction)); - - auto device = instruction->sharding_unique_device(); - if (device) { - node_def->set_device(GetDeviceName(*device)); - } - SetNodeAttrs(instruction, node_def); - if (instruction->opcode() == HloOpcode::kFusion) { - for (auto* fused_instruction : instruction->fused_instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(fused_instruction)); - } - } - // Add all edges including control edges. 
- for (unsigned i = 0; i < instruction->operands().size(); ++i) { - *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i)); - } - // Called computations are control dependencies. - for (const auto* called_computation : instruction->called_computations()) { - *node_def->add_input() = StrCat( - "^", GetNodeNameForInstruction(called_computation->root_instruction())); - } - return Status::OK(); -} - -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h deleted file mode 100644 index c4876b852e32d34693202f4023aa20ad2b301ffd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" - -namespace xla { -namespace hlo_graph_dumper { - -// This constructs a tensorflow graph for HLO computations. 
-class HloTfGraphBuilder { - public: - HloTfGraphBuilder(const DebugOptions& debug_options = DebugOptions()); - - // Adds a computation to the graph. - Status AddComputation(const HloComputation& computation); - - const tensorflow::GraphDef& GetGraphDef() const; - - private: - // Gets the node name of an instruction. The node name is hierarchical. For - // example, if an instruction is fused, it will be put in a subgraph of the - // fusion instruction. - const string& GetNodeNameForInstruction(const HloInstruction* instruction); - - void SetNodeAttrs(const HloInstruction* instruction, - tensorflow::NodeDef* node_def) const; - - Status AddInstruction(const HloInstruction* instruction); - - DebugOptions debug_options_; - tensorflow::GraphDef graph_def_; - // This records instructions that have been visited. - std::unordered_set visited_instructions_; - // A cache that maps instruction to the node name. - std::unordered_map instruction_to_node_name_; -}; - -} // namespace hlo_graph_dumper -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc deleted file mode 100644 index 1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ /dev/null @@ -1,183 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -using ::tensorflow::GraphDef; - -class HloTfGraphBuilderTest : public HloTestBase { - protected: - HloTfGraphBuilderTest() {} - HloTfGraphBuilder generator_; - - // Create a computation which takes a scalar and returns its negation. - std::unique_ptr CreateNegateComputation() { - auto builder = HloComputation::Builder("Negate"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - return builder.Build(); - } - - // Creates a computation which calls map with the given computation. 
- std::unique_ptr CreateMapComputation( - HloComputation *map_computation) { - auto builder = HloComputation::Builder("Map"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map_computation)); - return builder.Build(); - } - Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {}); -}; - -static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node, - const string &attr_name) { - auto attr = node.attr().find(attr_name); - CHECK(attr != node.attr().end()); - return attr->second; -} - -TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { - auto builder = HloComputation::Builder("Concatenate"); - Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, shape, "param1")); - builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1)); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - const auto &node = graph_def.node(2); - EXPECT_EQ(node.name(), "Concatenate/concatenate"); - - // Check dimensions. - auto dims_value = GetNodeAttr(node, "dims"); - EXPECT_EQ(dims_value.list().i_size(), 1); - EXPECT_EQ(dims_value.list().i(0), 1); - - // Check shapes. 
- auto shape_value = GetNodeAttr(node, "_output_shapes"); - EXPECT_EQ(shape_value.list().shape_size(), 1); - EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2); - EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2); - EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4); -} - -TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { - auto builder = HloComputation::Builder("Const"); - HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); - OpMetadata metadata; - metadata.set_op_name("x"); - metadata.set_op_type("y"); - instruction->set_metadata(metadata); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 1); - const auto &node = graph_def.node(0); - EXPECT_EQ(GetNodeAttr(node, "value").s(), "123"); - EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32"); - EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x"); - EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y"); -} - -TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) { - auto negate_computation = CreateNegateComputation(); - TF_CHECK_OK(generator_.AddComputation(*negate_computation)); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 2); - EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0"); - EXPECT_EQ(graph_def.node(0).op(), "HloParameter"); - EXPECT_EQ(graph_def.node(1).name(), "Negate/negate"); - EXPECT_EQ(graph_def.node(1).op(), "HloNegate"); - EXPECT_EQ(graph_def.node(1).input_size(), 1); - EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0"); -} - -TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { - auto builder = HloComputation::Builder("GE"); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32_, "param1")); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, 
HloOpcode::kGe, param_1, param_2)); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); - EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); - EXPECT_EQ(graph_def.node(2).input_size(), 2); - EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to"); - EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); -} - -TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) { - auto builder = HloComputation::Builder("GE"); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32_, "param1")); - auto ge = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); - OpMetadata metadata; - metadata.set_op_name("x/y"); - metadata.set_op_type("Y"); - ge->set_metadata(metadata); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); - EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); - EXPECT_EQ(graph_def.node(2).input_size(), 2); - EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to"); - EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); -} - -TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { - // Create computations with a diamond-shaped callgraph. 
- auto negate_computation = CreateNegateComputation(); - auto map1_computation = CreateMapComputation(negate_computation.get()); - auto map2_computation = CreateMapComputation(negate_computation.get()); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto map1 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); - auto map2 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); - auto computation = builder.Build(); - TF_CHECK_OK(generator_.AddComputation(*computation)); - EXPECT_GT(generator_.GetGraphDef().node_size(), 0); -} - -} // namespace -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h deleted file mode 100644 index 4458c251dee4af365e39027dd4289925c8890efd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Defines different kinds of tokens in a hlo module string. -// -// You shouldn't need to use this directly unless you're using HloLexer -// directly, and you probably don't need to do that. Use hlo_parser instead. -enum class TokKind { - // Markers - kEof, - kError, - - // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kLsquare, - kRsquare, // [ ] - kLbrace, - kRbrace, // { } - kLparen, - kRparen, // ( ) - - kArrow, // -> - - // Keywords - kw_HloModule, - kw_ENTRY, - kw_ROOT, - kw_true, - kw_false, - kw_maximal, - kw_replicated, - kw_nan, - kw_inf, - - kNegInf, // -inf - - // Typed tokens. - kName, // %foo - kAttributeName, // dimensions= - kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} - kDxD, // [0-9]+(x[0-9]+)+ - kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* - kIdent, // other identifiers - kString, // "abcd\"\n" - kShape, // f32[2,3]{1,0} - kInt, // 42 - kDecimal, // 4.2 -}; - -string TokKindToString(TokKind kind); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 59594ab2f0f70a206c73e998dbfa69c2c5c7ba43..218b33b2ac2b86edc30b2f014ba206c71da37682 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -46,7 +46,7 @@ const Shape& HloPosition::shape() const { string HloPosition::ToString() const { string index_str = - ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; + instruction->shape().IsTuple() ? 
(" " + index.ToString()) : ""; return StrCat(instruction->name(), index_str); } @@ -56,10 +56,9 @@ std::ostream& operator<<(std::ostream& out, const HloPosition& position) { } string HloUse::ToString() const { - string index_str = - ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) - ? (" " + operand_index.ToString()) - : ""; + string index_str = instruction->operand(operand_number)->shape().IsTuple() + ? (" " + operand_index.ToString()) + : ""; return StrCat(instruction->name(), ", operand ", operand_number, index_str); } @@ -88,7 +87,7 @@ bool HloValue::operator!=(const HloValue& other) const { } string HloValue::ToShortString() const { - string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) + string index_str = defining_instruction()->shape().IsTuple() ? defining_index().ToString() : ""; return StrCat(id(), " ", is_phi_ ? "PHI " : "", @@ -210,7 +209,7 @@ std::ostream& operator<<(std::ostream& out, const HloValue& value) { } void HloValueSet::SortAndUniquifyValues() { - std::sort(values_.begin(), values_.end(), HloValue::IdLessThan); + absl::c_sort(values_, HloValue::IdLessThan); values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), values_.end()); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 77db7b098a38ff4efdcc7447935fae61561c9ff4..56a06a182a236070340075848d301be54c0d9ebd 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -44,12 +44,13 @@ bool IsCallerInstruction(HloInstruction* hlo) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kWhile: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: return true; default: @@ -57,15 +58,6 @@ bool 
IsCallerInstruction(HloInstruction* hlo) { } } -Status ShapeVerifier::Preprocess(HloInstruction* hlo) { - if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) { - return InternalError( - "Called computations specified for non-caller instruction %s", - hlo->ToString()); - } - return VerifyNotSparse(hlo->shape()); -} - namespace { Status CheckOperandCount(const HloInstruction* hlo, int expected) { @@ -90,6 +82,21 @@ Status CheckParameterCount(const HloInstruction* calling_instruction, } // namespace +Status ShapeVerifier::Preprocess(HloInstruction* hlo) { + if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) { + return InternalError( + "Called computations specified for non-caller instruction %s", + hlo->ToString()); + } + TF_RETURN_IF_ERROR(VerifyNotSparse(hlo->shape())); + + absl::optional arity = HloOpcodeArity(hlo->opcode()); + if (arity) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity)); + } + return Status::OK(); +} + Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { return CheckUnaryShape(hlo); } @@ -121,14 +128,12 @@ Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { } Status ShapeVerifier::HandleConvert(HloInstruction* convert) { - TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { - TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferBitcastConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); @@ -139,7 +144,6 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { - TF_RETURN_IF_ERROR(CheckOperandCount(dot, 2)); TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), @@ -148,18 
+152,16 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) { } Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { - TF_RETURN_IF_ERROR(CheckOperandCount(convolution, 2)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->feature_group_count(), convolution->window(), - convolution->convolution_dimension_numbers())); + convolution->feature_group_count(), convolution->batch_group_count(), + convolution->window(), convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } Status ShapeVerifier::HandleFft(HloInstruction* fft) { - TF_RETURN_IF_ERROR(CheckOperandCount(fft, 1)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), @@ -167,13 +169,20 @@ Status ShapeVerifier::HandleFft(HloInstruction* fft) { return CheckShape(fft, expected); } -Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { +Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferTriangularSolveShape( + hlo->operand(0)->shape(), hlo->operand(1)->shape(), + hlo->triangular_solve_options())); + return CheckShape(hlo, expected); +} + +Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { std::vector operand_shapes; for (const HloInstruction* operand : crs->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(crs, - ShapeInference::InferCrossReplicaSumShape(operand_shapes)); + return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes)); } Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { @@ -185,14 +194,16 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) { + return CheckShape(hlo, 
ShapeUtil::MakeShape(U32, {})); +} + Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { - TF_RETURN_IF_ERROR(CheckOperandCount(reduce_precision, 1)); return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), reduce_precision->exponent_bits(), @@ -226,7 +237,6 @@ Status ShapeVerifier::CheckOperandAndParameter( } Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); HloInfeedInstruction* infeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -237,7 +247,6 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { } Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); HloOutfeedInstruction* outfeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); @@ -313,7 +322,6 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { - TF_RETURN_IF_ERROR(CheckOperandCount(reverse, 1)); return CheckShape( reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), reverse->dimensions())); @@ -324,13 +332,48 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError("Expected at least 1 operand for %s instruction: %s", HloOpcodeString(sort->opcode()), sort->ToString()); } + HloComputation* compare = sort->to_apply(); + + // Check that the 'compare' computation returns a PRED. 
+ Shape compare_shape = compare->root_instruction()->shape(); + if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { + return InternalError( + "The Sort compare computation shape does not lead to a scalar " + "predicate shape: %s", + StringifyShape(compare_shape)); + } + + // Check that the number of parameters of the 'compare' computation is + // correct. + TF_RETURN_IF_ERROR( + CheckParameterCount(sort, compare, sort->operand_count() * 2)); + + // Verify that the operands of the compare computation have the correct scalar + // shapes. + for (int64 parameter_idx = 0; parameter_idx < compare->num_parameters(); + ++parameter_idx) { + int64 operand_idx = parameter_idx / 2; + Shape expected_scalar_shape = ShapeUtil::MakeShape( + sort->operand(operand_idx)->shape().element_type(), {}); + Shape actual_parameter_shape = + compare->parameter_instruction(parameter_idx)->shape(); + if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape, + actual_parameter_shape)) { + return InternalError( + "Expected the %lld-th parameter of the compare computation of sort " + "to have shape %s, but got %s", + parameter_idx, StringifyShape(expected_scalar_shape), + StringifyShape(actual_parameter_shape)); + } + } + + // Verify that all operand shapes have the same dimensions. for (int64 operand = 1; operand < sort->operand_count(); ++operand) { if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), sort->operand(operand)->shape())) { return InternalError( - "Expected sort to have to have the same dimensions for the keys " - "and the values. Keys shape is: %s\n, Values shape (operand index " - "%lld) is: %s", + "Expected sort to have to have the same dimensions for all operands. 
" + "First operand shape is: %s\n, shape (operand index %lld) is: %s", StringifyShape(sort->operand(0)->shape()), operand, StringifyShape(sort->operand(operand)->shape())); } @@ -339,7 +382,6 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { - TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0)); if (!Cast(constant)->HasLiteral()) { return InternalError("Constant is required to have a valid literal: %s", constant->ToString()); @@ -348,9 +390,11 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { } Status ShapeVerifier::HandleIota(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast(instruction); - const int64 rank = ShapeUtil::Rank(iota->shape()); + if (!iota->shape().IsArray()) { + return InternalError("Iota does not support non-array result."); + } + const int64 rank = iota->shape().rank(); if (rank == 0) { return InternalError("Iota does not support scalars."); } @@ -363,13 +407,30 @@ Status ShapeVerifier::HandleIota(HloInstruction* instruction) { } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { - TF_RETURN_IF_ERROR(CheckOperandCount(get_tuple_element, 1)); return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( get_tuple_element->operand(0)->shape(), get_tuple_element->tuple_index())); } +namespace { +Status SameElementTypesForOperandsAndToApplyParameters( + const HloInstruction& instruction, int64 num_operands_to_check) { + const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape(); + for (int i = 0; i < num_operands_to_check; ++i) { + const Shape& parameter_shape = to_apply.parameters(i); + const Shape& operand_shape = instruction.operands()[i]->shape(); + if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) { + return InvalidArgument( + "Shape mismatch between to_apply computation" + " parameter and operand %d in %s.", + i, 
instruction.ToString().c_str()); + } + } + return Status::OK(); +} +} // namespace + Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { if (reduce->operand_count() % 2 != 0) { return InternalError( @@ -381,30 +442,40 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { for (const HloInstruction* operand : reduce->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(reduce, ShapeInference::InferReduceShape( - operand_shapes, reduce->dimensions(), - reduce->to_apply()->ComputeProgramShape())); + TF_RETURN_IF_ERROR( + CheckShape(reduce, ShapeInference::InferReduceShape( + operand_shapes, reduce->dimensions(), + reduce->to_apply()->ComputeProgramShape()))); + + return allow_mixed_precision_ + ? Status::OK() + : SameElementTypesForOperandsAndToApplyParameters( + *reduce, reduce->operands().size() - 1); } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); + // Bitcasts are not allowed to change the element type. + if (bitcast->operand(0)->shape().element_type() != + bitcast->shape().element_type()) { + return InternalError( + "Bitcast can not change the element type from %s to %s", + PrimitiveType_Name(bitcast->operand(0)->shape().element_type()), + PrimitiveType_Name(bitcast->shape().element_type())); + } return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { - TF_RETURN_IF_ERROR(CheckOperandCount(broadcast, 1)); // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); // Check for mixed precision. 
TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)); - TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == - broadcast->dimensions().size()); - for (int64 operand_dimension = 0; - operand_dimension < ShapeUtil::Rank(operand_shape); + TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size()); + for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank(); ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && + TF_RET_CHECK((output_dimension < broadcast->shape().rank()) && output_dimension >= 0 && (broadcast->shape().dimensions(output_dimension) == operand_shape.dimensions(operand_dimension))) @@ -414,7 +485,6 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { - TF_RETURN_IF_ERROR(CheckOperandCount(reshape, 1)); // Check for mixed precision. const Shape& operand_shape = reshape->operand(0)->shape(); TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape)); @@ -424,14 +494,12 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { - TF_RETURN_IF_ERROR(CheckOperandCount(transpose, 1)); return CheckShape( transpose, ShapeInference::InferTransposeShape( transpose->operand(0)->shape(), transpose->dimensions())); } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 0)); return Status::OK(); } @@ -481,7 +549,9 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { const Shape& operand_shape_with_layout = custom_call->operand_shapes_with_layout()[i]; TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), - operand_shape_with_layout)); + operand_shape_with_layout)) + << custom_call->operand(i)->shape().ToString() << " operand " + << operand_shape_with_layout.ToString(); 
TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); } } @@ -489,7 +559,6 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(slice, 1)); return CheckShape(slice, ShapeInference::InferSliceShape( slice->operand(0)->shape(), slice->slice_starts(), @@ -497,21 +566,23 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); - return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( - dynamic_slice->operand(0)->shape(), - dynamic_slice->operand(1)->shape(), - dynamic_slice->dynamic_slice_sizes())); + return CheckShape( + dynamic_slice, + ShapeInference::InferDynamicSliceShape( + dynamic_slice->operand(0)->shape(), + Cast(dynamic_slice)->index_shapes(), + dynamic_slice->dynamic_slice_sizes())); } Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); - return CheckShape(dynamic_update_slice, - ShapeInference::InferDynamicUpdateSliceShape( - dynamic_update_slice->operand(0)->shape(), - dynamic_update_slice->operand(1)->shape(), - dynamic_update_slice->operand(2)->shape())); + return CheckShape( + dynamic_update_slice, + ShapeInference::InferDynamicUpdateSliceShape( + dynamic_update_slice->operand(0)->shape(), + dynamic_update_slice->operand(1)->shape(), + Cast(dynamic_update_slice) + ->index_shapes())); } Status ShapeVerifier::HandleTuple(HloInstruction* tuple) { @@ -523,30 +594,39 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { int64 max_operand_rank = 0; for (const HloInstruction* operand : map->operands()) { operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + max_operand_rank = 
std::max(max_operand_rank, operand->shape().rank()); } // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. std::vector map_dims(max_operand_rank); std::iota(map_dims.begin(), map_dims.end(), 0); - return CheckShape(map, ShapeInference::InferMapShape( - operand_shapes, - map->to_apply()->ComputeProgramShape(), map_dims)); + + TF_RETURN_IF_ERROR(CheckShape( + map, + ShapeInference::InferMapShape( + operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims))); + + return allow_mixed_precision_ + ? Status::OK() + : SameElementTypesForOperandsAndToApplyParameters( + *map, map->operands().size()); } Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { - TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2)); - return CheckShape( + TF_RETURN_IF_ERROR(CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( reduce_window->operand(0)->shape(), reduce_window->operand(1)->shape(), reduce_window->window(), - reduce_window->to_apply()->ComputeProgramShape())); + reduce_window->to_apply()->ComputeProgramShape()))); + + return allow_mixed_precision_ + ? 
Status::OK() + : SameElementTypesForOperandsAndToApplyParameters(*reduce_window, + 1); } Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape( instruction, ShapeInference::InferSelectAndScatterShape( @@ -557,7 +637,6 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { - TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1)); TF_RETURN_IF_ERROR( CheckParameterCount(xla_while, xla_while->while_body(), 1)); TF_RETURN_IF_ERROR( @@ -581,7 +660,6 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { - TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3)); TF_RETURN_IF_ERROR( CheckParameterCount(conditional, conditional->true_computation(), 1)); TF_RETURN_IF_ERROR( @@ -600,14 +678,12 @@ Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { } Status ShapeVerifier::HandlePad(HloInstruction* pad) { - TF_RETURN_IF_ERROR(CheckOperandCount(pad, 2)); return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), pad->operand(1)->shape(), pad->padding_config())); } Status ShapeVerifier::HandleSend(HloInstruction* send) { - TF_RETURN_IF_ERROR(CheckOperandCount(send, 2)); return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), @@ -615,12 +691,10 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { - TF_RETURN_IF_ERROR(CheckOperandCount(send_done, 1)); return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { - TF_RETURN_IF_ERROR(CheckOperandCount(recv, 1)); return CheckShape( recv, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(recv->shape(), 0), @@ -628,7 +702,6 @@ Status 
ShapeVerifier::HandleRecv(HloInstruction* recv) { } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { - TF_RETURN_IF_ERROR(CheckOperandCount(recv_done, 1)); return CheckShape( recv_done, ShapeUtil::MakeTupleShape( @@ -638,7 +711,6 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { Status ShapeVerifier::HandleBatchNormTraining( HloInstruction* batch_norm_training) { - TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_training, 3)); return CheckShape(batch_norm_training, ShapeInference::InferBatchNormTrainingShape( batch_norm_training->operand(0)->shape(), @@ -649,7 +721,6 @@ Status ShapeVerifier::HandleBatchNormTraining( Status ShapeVerifier::HandleBatchNormInference( HloInstruction* batch_norm_inference) { - TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_inference, 5)); return CheckShape(batch_norm_inference, ShapeInference::InferBatchNormInferenceShape( batch_norm_inference->operand(0)->shape(), @@ -661,7 +732,6 @@ Status ShapeVerifier::HandleBatchNormInference( } Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { - TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_grad, 5)); return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( batch_norm_grad->operand(0)->shape(), batch_norm_grad->operand(1)->shape(), @@ -683,7 +753,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConstant: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kFusion: @@ -694,7 +764,6 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReducePrecision: - case HloOpcode::kSelect: case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -730,7 +799,6 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // 
namespace Status ShapeVerifier::HandleGather(HloInstruction* gather) { - TF_RETURN_IF_ERROR(CheckOperandCount(gather, 2)); return CheckShape( gather, ShapeInference::InferGatherShape( @@ -739,7 +807,6 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { - TF_RETURN_IF_ERROR(CheckOperandCount(scatter, 3)); return CheckShape( scatter, ShapeInference::InferScatterShape( scatter->operand(0)->shape(), scatter->operand(1)->shape(), @@ -757,7 +824,6 @@ Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { } Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) { - TF_RETURN_IF_ERROR(CheckOperandCount(add_dependency, 2)); TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1)); return CheckShape(add_dependency, add_dependency->operand(0)->shape()); } @@ -839,14 +905,12 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); return CheckShape(instruction, ShapeInference::InferUnaryOpShape(instruction->opcode(), instruction->operand(0))); } Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); return CheckShape( instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), instruction->operand(0), @@ -854,7 +918,6 @@ Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { } Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape(instruction, ShapeInference::InferTernaryOpShape( instruction->opcode(), instruction->operand(0), @@ -982,7 +1045,7 @@ bool ShapeContainsToken(const Shape& shape) { bool contains_token = false; ShapeUtil::ForEachSubshape( shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { - if 
(ShapeUtil::IsToken(subshape)) { + if (subshape.IsToken()) { contains_token = true; } }); @@ -1230,8 +1293,8 @@ Status CheckFusionInstruction(HloInstruction* fusion) { return Status::OK(); } -// Checks that the non-scalar operand shapes are compatible to the output -// shape, i.e., that there are no implicit broadcasts of size-one dimensions. +// Checks that the operand shapes are compatible to the output shape, i.e., +// that there are no implicit broadcasts. Status CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { @@ -1270,11 +1333,11 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I // or ComputationLowerer::Visit() TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(broadcast->operand(0)->shape())) + broadcast->operand(0)->shape().rank()) << "Broadcast HLO (" << broadcast->ToShortString() << ") has invalid number of dimensions: " << broadcast->dimensions().size() - << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape()); + << " != " << broadcast->operand(0)->shape().rank(); return Status::OK(); } @@ -1324,7 +1387,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { } Status HandleGetTupleElement(HloInstruction* gte) override { - TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape())); + TF_RET_CHECK(gte->operand(0)->shape().IsTuple()); return Status::OK(); } @@ -1344,7 +1407,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleCrossReplicaSum(HloInstruction* crs) override { + Status HandleAllReduce(HloInstruction* crs) override { if (crs->all_reduce_id().has_value()) { TF_RET_CHECK(crs->all_reduce_id().value() > 0) << "All reduce id must be greater than 0 for " @@ -1375,7 +1438,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { for (HloInstruction* operand : 
instruction->operands()) { const Shape& operand_shape = operand->shape(); if (LayoutUtil::IsDenseArray(operand_shape) && - ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + operand_shape.rank() == result_shape.rank()) { const Layout& operand_layout = operand_shape.layout(); TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) << "Instruction shouldn't change layouts " diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index e4d0c3d6957885f1d719fedb5a900de601e397f8..a9b5e9a3e6eec19e125188a192694fcaadfe2322 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -52,9 +52,11 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleTriangularSolve(HloInstruction* hlo) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; + Status HandleReplicaId(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -168,8 +170,13 @@ class ShapeVerifier : public DfsHloVisitor { // An interface used to encapsulate target-specific verification quirks. class TargetVerifierMetadata { public: + TargetVerifierMetadata(std::function shape_size_function) + : shape_size_function_(shape_size_function) {} + // Returns a target-specific shape size. 
- virtual int64 ShapeSize(const Shape& shape) const = 0; + int64 ShapeSize(const Shape& shape) const { + return shape_size_function_(shape); + } virtual std::unique_ptr GetVerifier() const = 0; @@ -178,20 +185,23 @@ class TargetVerifierMetadata { TargetVerifierMetadata(const TargetVerifierMetadata&) = delete; TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete; + + private: + // Returns a target-specific shape size. + std::function shape_size_function_; }; // The default implementation of TargetVerifierMetadata, used unless the target // needs to override it. class DefaultVerifierMetadata : public TargetVerifierMetadata { public: - DefaultVerifierMetadata(bool layout_sensitive, bool allow_mixed_precision) - : layout_sensitive_(layout_sensitive), + DefaultVerifierMetadata( + bool layout_sensitive, bool allow_mixed_precision, + std::function shape_size_function) + : TargetVerifierMetadata(shape_size_function), + layout_sensitive_(layout_sensitive), allow_mixed_precision_(allow_mixed_precision) {} - int64 ShapeSize(const Shape& shape) const override { - return ShapeUtil::ByteSizeOf(shape); - } - // Creates a ShapeVerifier that checks that shapes match inferred // expectations. This creates a new verifier every time because ShapeVerifier, // being a DfsHloVisitor, is stateful. We want a clean object for each run of @@ -210,11 +220,14 @@ class DefaultVerifierMetadata : public TargetVerifierMetadata { // the module. 
class HloVerifier : public HloModulePass { public: - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, - std::function - instruction_can_change_layout_func = {}) + explicit HloVerifier( + bool layout_sensitive, bool allow_mixed_precision, + std::function + instruction_can_change_layout_func = {}, + std::function shape_size_func = + [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) : target_metadata_(absl::make_unique( - layout_sensitive, allow_mixed_precision)), + layout_sensitive, allow_mixed_precision, shape_size_func)), instruction_can_change_layout_func_( std::move(instruction_can_change_layout_func)) { CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 4bc557e4e62e7df4e25fda86fe417e84129b464c..523890b3c7268c06cdb6aaa67749f26a1cb62855 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -27,6 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -386,6 +388,55 @@ TEST_F(HloVerifierTest, AddWithLayoutChange) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) { + const char* const kScalarIndexDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258] parameter(0) + %constant = s32[] constant(0) + %start_index = s32[] parameter(1) + ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kScalarIndexDynamicSlice, config)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) { + const char* const kScalarIndexDynamicSlice = R"( + HloModule DynamicUpdateSlice_module + + ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_index.0 = s32[] parameter(2) + %start_index.1 = s32[] parameter(3) + %start_index.2 = s32[] parameter(4) + %start_index.3 = s32[] parameter(5) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} 
%update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3) + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kScalarIndexDynamicSlice, config)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); auto status = verifier().Run(module.get()).status(); @@ -399,8 +450,9 @@ TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { HloModule SliceWithLayoutChange ENTRY SliceWithLayoutChange { par0 = f32[4,5]{0,1} parameter(0) - par1 = s32[2] parameter(1) - ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + par1 = s32[] parameter(1) + par2 = s32[] parameter(2) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2), dynamic_slice_sizes={3,4} } )"; @@ -429,5 +481,138 @@ TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { EXPECT_THAT(status.error_message(), HasSubstr("Instruction shouldn't change layouts")); } + +TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY BitcastCanNotChangeElementType { + constant.0 = f32[2] constant({0.0, 0.0}) + ROOT bitcast = s32[2] bitcast(constant.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Bitcast can not change the element type")); +} + +TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY SelectMixedPrecisionNotAllowed { + p0 = pred[] parameter(0) + p1 = 
f32[32] parameter(1) + p2 = bf16[32] parameter(2) + ROOT select = f32[32] select(p0, p1, p2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Seen floating point types of different precisions")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY SelectMixedPrecisionAllowed { + p0 = pred[] parameter(0) + p1 = f32[32] parameter(1) + p2 = bf16[32] parameter(2) + ROOT select = f32[32] select(p0, p1, p2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, IotaNonArrayResult) { + const char* const hlo_string = R"( + HloModule IotaTupleResult + + ENTRY kernelEntry { + ROOT iota = () iota(), iota_dimension=24 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("does not support non-array result")); +} + +static const char* const kMapOperandComputationMismatchHlo = R"( + HloModule MapOperandComputationMismatch + + Computation { + param0 = f32[] parameter(0) + constant = f32[] constant(1) + ROOT add = f32[] add(param0, constant) + } + + ENTRY kernelEntry { + param = f64[] parameter(0) + ROOT map = f32[] map(param), dimensions={}, to_apply=Computation +})"; + +TEST_F(HloVerifierTest, MapOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kMapOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.error_message(), + HasSubstr( + "Shape mismatch between to_apply computation parameter and operand")); +} + 
+TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kMapOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +static const char* const kReduceOperandComputationMismatchHlo = R"( + HloModule ReduceOperandComputationMismatch + computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY kernelEntry { + arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0) + constant = f16[] constant(0) + reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation + })"; + +TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kReduceOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to f32[64]")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kReduceOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 90904ac00110457bcc3b8974816a7080c4ab89fc..88fc62bd1e2a7830b3f61738a8642308ef4225a7 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -128,9 +128,9 @@ string HumanReadableProfileBuilder::ToString() const { // Sort ops in decreasing order of cycles, and print them. 
std::vector sorted_ops(op_infos_); - std::sort( - sorted_ops.begin(), sorted_ops.end(), - [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); + absl::c_sort(sorted_ops, [](const OpInfo& a, const OpInfo& b) { + return a.cycles > b.cycles; + }); for (const auto& op : sorted_ops) { print_op(op); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 1ebb3319779c00fd4afe90606bf336e16349429d..c5d32a4b9ad8c708ec0870173fa72320238e8464 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" namespace xla { -namespace gtl = ::tensorflow::gtl; namespace { using Analysis = IndexedArrayAnalysis; @@ -103,7 +102,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( do { const HloInstruction* instr = stack.back(); - if (cache_.count(instr)) { + if (cache_.contains(instr)) { stack.pop_back(); continue; } @@ -111,9 +110,9 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( switch (FindOrDie(dfs_state_map, instr)) { case kDiscovered: { for (const HloInstruction* operand : instr->operands()) { - if (!cache_.count(operand)) { + if (!cache_.contains(operand)) { stack.push_back(operand); - CHECK(!dfs_state_map.count(operand) || + CHECK(!dfs_state_map.contains(operand) || dfs_state_map[operand] == kDiscovered); dfs_state_map[operand] = kDiscovered; } @@ -1002,7 +1001,7 @@ bool CanFoldDotIntoIndexedArray( absl::Span contracting_dims, absl::Span batch_dims) { absl::optional non_contracting_non_batch_dim = - GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), + GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(), contracting_dims, batch_dims); if (!non_contracting_non_batch_dim.has_value()) { VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions"; @@ -1015,7 
+1014,7 @@ bool CanFoldDotIntoIndexedArray( return false; } - int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape()); + int64 indexed_array_rank = indexed_array->shape().rank(); if (indexed_array->source_dim() < (indexed_array_rank - 2)) { // This restriction can be lifted by inserting reshape nodes. VLOG(3) << tag @@ -1043,7 +1042,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( return nullptr; } - int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + int64 lhs_rank = lhs->shape().rank(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); @@ -1078,7 +1077,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( return nullptr; } - int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + int64 rhs_rank = rhs->shape().rank(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_rhs_contracting_dimensions( diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 98246d5403e4aebc2f4d81e52145706355ddd9a9..62107b5a88d4e37552fa5a6384700a9291a9c655 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include - #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "absl/strings/ascii.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -43,7 +42,7 @@ class IndexedArrayAnalysisTest : public HloTestBase { string result; for (char c : text) { - if (!isspace(c)) { + if (!absl::ascii_isspace(c)) { result.push_back(c); } else if (!result.empty() && result.back() != ' ') { result.push_back(' '); @@ -99,7 +98,7 @@ TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), offset_dims={1}, @@ -119,7 +118,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), offset_dims={}, @@ -195,7 +194,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), @@ -309,7 +308,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), offset_dims={1}, @@ -330,7 +329,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { 
HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), offset_dims={1}, @@ -352,7 +351,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,2,6] constant(s32[3,2,6]{ + operand = s32[3,2,6] constant({ {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) @@ -377,7 +376,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), @@ -405,7 +404,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + operand = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, @@ -438,7 +437,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + operand = s32[1,6] constant({{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), offset_dims={1}, @@ -465,7 +464,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,2,6] constant(s32[1,2,6]{{ + operand = s32[1,2,6] constant({{ {1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), @@ -496,7 +495,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ 
{1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), @@ -527,7 +526,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), offset_dims={1}, @@ -556,7 +555,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,5,2] constant(s32[3,5,2]{ + operand = s32[3,5,2] constant({ {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}}) @@ -588,7 +587,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4,1] constant(s32[3,4,1]{ + operand = s32[3,4,1] constant({ {{1},{2},{3},{4}}, {{1},{2},{3},{4}}, {{1},{2},{3},{4}}}) @@ -620,7 +619,7 @@ TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { HloModule UnaryOpOfGather ENTRY main { - operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + operand = f32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), offset_dims={1}, @@ -645,7 +644,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { HloModule AddBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -673,7 +672,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] 
constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -701,7 +700,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -728,7 +727,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[4] constant({10,11,12,13}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} indices = s32[5] parameter(0) @@ -755,7 +754,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[5] constant({10,11,12,13,14}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) @@ -804,8 +803,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), offset_dims={1}, @@ -831,8 +830,8 @@ 
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[3,3] constant({{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -859,8 +858,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -888,8 +887,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) { HloModule DotOp ENTRY main { - gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), offset_dims={1}, @@ -917,8 +916,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) { HloModule DotOp ENTRY main { - gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) - dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) + gather_operand = s32[2,3,2] constant({{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) + dot_lhs_constant = s32[2,2,3] constant({{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = 
s32[2,3,4] gather(gather_operand, indices), offset_dims={0,1}, @@ -948,8 +947,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpNegative) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[2,3] constant({{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), offset_dims={0}, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 7559ed1bab84b21a4d51bc38db999900befcfad7..f5770eee2250511c0e29e434f224b4ff347142ba 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" @@ -94,6 +95,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kPad: case HloOpcode::kReal: case HloOpcode::kReducePrecision: + case HloOpcode::kReplicaId: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: @@ -126,7 +128,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConvolution: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: @@ -149,13 +151,16 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: + case HloOpcode::kRsqrt: case HloOpcode::kScatter: 
case HloOpcode::kSelectAndScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kSort: + case HloOpcode::kSqrt: case HloOpcode::kTanh: case HloOpcode::kTrace: + case HloOpcode::kTriangularSolve: case HloOpcode::kWhile: case HloOpcode::kGetDimensionSize: return true; @@ -173,23 +178,22 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { ShapeUtil::ForEachSubshape( hlo->shape(), [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape)); } }); - return std::count_if(hlo->operands().begin(), hlo->operands().end(), - [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kIota) { - return false; - } - if (operand->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(operand->shape())) { - return false; - } - return ShapeUtil::TrueRank(operand->shape()) >= - output_rank; - }) <= 1; + return absl::c_count_if( + hlo->operands(), [output_rank](HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kIota) { + return false; + } + if (operand->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(operand->shape())) { + return false; + } + return ShapeUtil::TrueRank(operand->shape()) >= output_rank; + }) <= 1; } bool InstructionFusion::CanFuseOnAllPaths( @@ -273,7 +277,7 @@ InstructionFusion::ComputeGloballyUnfusible( ShapeUtil::ForEachSubshape( shape, [&size](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { size += ShapeUtil::ElementsIn(subshape); } }); @@ -408,9 +412,8 @@ class ReversePostOrderFusionQueue : public FusionQueue { } sorted_operand_numbers.push_back(i); } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 
j) { + absl::c_sort( + sorted_operand_numbers, [&](int64 i, int64 j) { // Instructions with higher priority in the queue come first. return ( FindOrDie(post_order_index_, instruction->mutable_operand(i)) > @@ -570,19 +573,42 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { - // A consumer operand may have been multi-output fused into a parallel - // consumer and thus be missing from the original reachability map. - if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { - reachability_ = HloReachabilityMap::Build(consumer->parent()); + absl::flat_hash_set operands; + for (const HloInstruction* operand : consumer->operands()) { + if (operand == producer) { + continue; + } + + // If the reachability map already contains the producer and the operand of + // the consumer, and the producer can reach the operand, then we know for + // sure MultiOutputFusion would create a cycle. If not, we need to do a DFS + // traversal of the computation to verify that this multioutput fusion would + // not create a cycle. + if (reachability_->IsPresent(producer) && + reachability_->IsPresent(operand) && + reachability_->IsReachable(producer, operand)) { + return true; } - return reachability_->IsReachable(a, b); - }; - return absl::c_any_of(consumer->operands(), - [&](const HloInstruction* consumer_operand) { - return consumer_operand != producer && - is_reachable(producer, consumer_operand); - }); + operands.insert(operand->unique_id()); + } + + // Do a DFS on the producer to see if any of the other consumer operands are + // reachable in the current state of the graph. 
+ std::vector worklist = producer->users(); + absl::flat_hash_set visits; + while (!worklist.empty()) { + const HloInstruction* user = worklist.back(); + worklist.pop_back(); + if (operands.count(user->unique_id()) != 0) { + return true; + } + if (visits.count(user->unique_id()) == 0) { + visits.insert(user->unique_id()); + worklist.insert(worklist.end(), user->users().begin(), + user->users().end()); + } + } + return false; } bool InstructionFusion::ShouldFuse(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 58b7135cea7419f13d60ed510ecf7a88126aee48..611cfd404d7622f561f0acc86fc9b05e16eea22e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -259,8 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -290,8 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} add(p0, p0) log = f32[4,3]{1,0} log(p0) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -324,8 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, 
token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index a981d94a999e3d322986bc2bfd56a5b0b5d175fc..8cd936268994c2a25c2c0debe0a003d1d05cbd0b 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -1,12 +1,12 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) - load( "//tensorflow/core:platform/default/build_config_root.bzl", "if_static", ) +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + cc_library( name = "interpreter_transfer_manager", srcs = ["interpreter_transfer_manager.cc"], @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -47,8 +48,11 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:lib", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", @@ -115,6 +119,8 @@ cc_library( 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/stream_executor/host:host_stream", + "//tensorflow/stream_executor/host:host_timer", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 3a5177c418e3af8253df228a51f2fc0901d10041..792773c676984aa280c1b20cb7fd0fc7c9425f6c 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -21,6 +21,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" @@ -31,7 +33,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/map_inliner.h" +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -40,12 +44,51 @@ limitations under the License. namespace xla { namespace interpreter { +namespace { + +// Handles custom_call ops during evaluation by routing them through the global +// CPU registry used by other CPU-based backends. 
+StatusOr HandleEvaluatorCustomCall( + HloInstruction* custom_call, absl::Span operands) { + // Find the target C function in the global registry. + auto* registry = xla::cpu::CustomCallTargetRegistry::Global(); + void* target_fn = registry->Lookup(custom_call->custom_call_target()); + if (!target_fn) { + return NotFound("Custom call target '%s' was not registered", + custom_call->custom_call_target()); + } + + // Populate pointers to operand and output literal data. + std::vector operand_data; + operand_data.reserve(operands.size()); + for (const auto* literal : operands) { + operand_data.push_back(literal->untyped_data()); + } + auto output = Literal::CreateFromShape(custom_call->shape()); + void* output_data = output.untyped_data(); + + // Call the target function matching the C ABI used by the CPU backends. + auto* typed_fn = reinterpret_cast(target_fn); + (*typed_fn)(output_data, operand_data.data()); + + return std::move(output); +} + +} // namespace + Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); + pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout); + + ReducePrecisionInsertion::AddPasses( + &pipeline, hlo_module->config().debug_options(), + ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); + return pipeline.Run(hlo_module).status(); } @@ -75,10 +118,15 @@ StatusOr> InterpreterCompiler::RunBackend( // In this case we are using an HloEvaluator at execution time, so we don't // need to compile anything + auto evaluator = absl::make_unique(); + evaluator->set_use_fast_path( + hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); + evaluator->set_custom_call_handler(HandleEvaluatorCustomCall); + // Create executable from only the Hlo module. 
std::unique_ptr executable = - absl::make_unique( - std::move(hlo_module), absl::make_unique()); + absl::make_unique(std::move(hlo_module), + std::move(evaluator)); return std::move(executable); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index de9204011ce5ba8a9fc2871c6bd7120b6ed371b5..7a6ebdef708bcc3a92fbd8618db0c42c35e6ce8b 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -68,6 +68,18 @@ StatusOr InterpreterExecutable::ExecuteOnStream( "Mismatch between argument count and graph parameter count."); } + // Check that the args have the right shape. + for (int64 i = 0; i < computation->num_parameters(); ++i) { + const auto& expected_shape = computation->parameter_instruction(i)->shape(); + const auto& actual_shape = arguments[i]->on_device_shape(); + if (!ShapeUtil::Equal(expected_shape, actual_shape)) { + return InvalidArgument( + "Shape mismatch on parameter %d. Expected %s, but was %s.", i, + ShapeUtil::HumanString(expected_shape), + ShapeUtil::HumanString(actual_shape)); + } + } + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, TransferManager::GetForPlatform(platform)); @@ -86,8 +98,8 @@ StatusOr InterpreterExecutable::ExecuteOnStream( { tensorflow::mutex_lock lock(evaluator_lock_); evaluator_->ResetVisitStates(); - TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate(*computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. 
@@ -117,7 +129,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream( } /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) { - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return sizeof(void*); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index eddef850cf5250b85b564c1e6c92d1cc8ecd1a43..aa791ea195e7a88fd8ad28fd0b60c88dea8a6928 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -147,12 +147,9 @@ bool LayoutConstraints::OperandBufferForwarded( PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); PointsToSet::BufferSet* operand_buffers = GetBufferSet(instruction->operand(operand_no)); - for (const LogicalBuffer* output_buffer : *output_buffers) { - if (operand_buffers->count(output_buffer) > 0) { - return true; - } - } - return false; + return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) { + return operand_buffers->count(b) > 0; + }); } Status LayoutConstraints::SetBufferLayout(const Layout& layout, @@ -256,7 +253,7 @@ Status LayoutConstraints::SetArrayOperandLayout( const Layout& layout, const HloInstruction* instruction, int64 operand_no, bool mandatory, bool dfs) { const HloInstruction* operand = instruction->operand(operand_no); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + TF_RET_CHECK(operand->shape().IsArray()); Shape shape(operand->shape()); *shape.mutable_layout() = layout; TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); @@ -314,7 +311,7 @@ Status LayoutConstraints::SetInstructionLayout( CHECK_EQ(1, buffers.size()); CHECK_EQ(buffers[0]->instruction(), instruction); - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { return SetBufferLayout(subshape.layout(), *buffers[0], mandatory); } else { return Status::OK(); @@ -406,7 +403,7 @@ Status 
LayoutAssignment::BuildHostChannelConstraints( instruction->opcode() == HloOpcode::kRecv) { const Shape& data_shape = ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0); - TF_RET_CHECK(ShapeUtil::IsArray(data_shape)); + TF_RET_CHECK(data_shape.IsArray()); TF_RET_CHECK(LayoutUtil::HasLayout(data_shape)); const Layout* prev_layout = host_channel_constraints_.ConstrainChannel( send_recv_instr->channel_id(), data_shape.layout()); @@ -489,7 +486,7 @@ Status LayoutAssignment::AddMandatoryConstraints( if (instruction->opcode() == HloOpcode::kSend) { // TODO(b/68493863): Change to use SetOperandLayout(). const Shape send_buffer_shape = instruction->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape)); + TF_RET_CHECK(send_buffer_shape.IsArray()); Shape new_buffer_shape = get_channel_constraints(instruction) ->LayoutShapeForChannel(send_buffer_shape, @@ -499,7 +496,7 @@ Status LayoutAssignment::AddMandatoryConstraints( } else { const Shape recv_buffer_shape = ShapeUtil::GetTupleElementShape(instruction->shape(), 0); - TF_RET_CHECK(ShapeUtil::IsArray(recv_buffer_shape)); + TF_RET_CHECK(recv_buffer_shape.IsArray()); TF_ASSIGN_OR_RETURN( const LogicalBuffer* buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, @@ -520,7 +517,7 @@ Status LayoutAssignment::AddMandatoryConstraints( } // TODO(b/68493863): Change to use SetOperandLayout(). const Shape& buffer_shape = instruction->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape)); + TF_RET_CHECK(buffer_shape.IsArray()); Shape new_buffer_shape = get_channel_constraints(instruction) ->LayoutShapeForChannel(buffer_shape, all_reduce_id); @@ -780,7 +777,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( << ShapeUtil::HumanString(instruction->shape()) << " instruction: " << instruction->ToString(); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { // Copy tuple elements which have differing layouts. 
std::vector element_copies; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); @@ -811,7 +808,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( shape_with_layout, tuple_copy->mutable_shape())); return tuple_copy; - } else if (ShapeUtil::IsArray(instruction->shape())) { + } else if (instruction->shape().IsArray()) { HloInstruction* copy = instruction->parent()->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction)); @@ -988,11 +985,10 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - CHECK(ShapeUtil::IsArray(instruction->shape())); - CHECK(ShapeUtil::IsArray(operand->shape())); + CHECK(instruction->shape().IsArray()); + CHECK(operand->shape().IsArray()); if (!ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape()) && + operand->shape().rank() == instruction->shape().rank() && !instruction_can_change_layout_func_(instruction)) { // Propagate the result layout to the operand layout if the instruction // requires the same layout out for the result and the operand. @@ -1012,7 +1008,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // operations. For similar reasons, if the operand and output have the same // rank, try to match the operand's layout to the output. if (ShapeUtil::TrueRank(operand->shape()) == 1 && - ShapeUtil::Rank(instruction->shape()) == 1) { + instruction->shape().rank() == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. 
return nullptr; } @@ -1026,7 +1022,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { return absl::make_unique(operand_shape.layout()); } - if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { + if (operand_shape.rank() == output_shape.rank()) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { @@ -1045,7 +1041,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kTranspose) { // Pick the operand layout that makes the transpose a bitcast. - int64 rank = ShapeUtil::Rank(instruction->shape()); + int64 rank = instruction->shape().rank(); std::vector new_minor_to_major(rank); for (int64 i = 0; i < rank; ++i) { int64 output_dim = LayoutUtil::Minor(output_layout, i); @@ -1066,11 +1062,10 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( int64 operand_no) { const HloInstruction* operand = user->operand(operand_no); - CHECK(ShapeUtil::IsArray(user->shape()) && - ShapeUtil::IsArray(operand->shape())); + CHECK(user->shape().IsArray() && operand->shape().IsArray()); if (!ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + operand->shape().rank() == user->shape().rank() && !instruction_can_change_layout_func_(user)) { // Assign users the same layout as the operand. return absl::make_unique(operand_layout); @@ -1083,7 +1078,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( // reshape is a bitcast when using the same layout. This may avoid copy // operations. For similar reasons, if the operand and output have the same // rank, try to match the outputs's layout to the operand. 
- if (ShapeUtil::Rank(operand->shape()) == 1 && + if (operand->shape().rank() == 1 && ShapeUtil::TrueRank(user->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. return nullptr; @@ -1098,7 +1093,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { return absl::make_unique(output_shape.layout()); } - if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { + if (operand->shape().rank() == output_shape.rank()) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { @@ -1117,7 +1112,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (user->opcode() == HloOpcode::kTranspose) { // Pick the user layout that makes the transpose a bitcast. - int64 rank = ShapeUtil::Rank(user->shape()); + int64 rank = user->shape().rank(); std::vector new_minor_to_major(rank); auto inverse_dimensions = InversePermutation(user->dimensions()); for (int64 i = 0; i < rank; ++i) { @@ -1193,7 +1188,7 @@ std::vector> GetArrayUsesOfBuffer( CHECK(buffer.IsArray()); std::vector> uses; for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) { - if (!ShapeUtil::IsArray(buffer_alias.instruction()->shape())) { + if (!buffer_alias.instruction()->shape().IsArray()) { continue; } // This alias must be the top-level (index == {}) of the instruction's @@ -1227,7 +1222,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { for (const LogicalBuffer* buffer : buffers) { if (constraints->BufferLayout(*buffer) == nullptr && - ShapeUtil::IsArray(buffer->shape())) { + buffer->shape().IsArray()) { TF_RETURN_IF_ERROR(constraints->SetBufferLayout( ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(), *buffer, /*mandatory=*/true)); @@ -1238,6 +1233,23 @@ Status 
LayoutAssignment::PropagateUseConstraintToDefs( }); } +namespace { +// A transpose or a reshape that only changes trivial dimensions have meaningful +// layouts that are valuable to propagate in a depthfirst manner to avoid +// unassigned layouts in the graph. +bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) { + switch (hlo.opcode()) { + case HloOpcode::kReshape: + return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); + case HloOpcode::kTranspose: + return true; + default: + return false; + } +} + +} // namespace + Status LayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& operand_constraint, LayoutConstraints* constraints) { @@ -1258,11 +1270,10 @@ Status LayoutAssignment::PropagateOperandConstraint( // layout for the operands with the same ranks. const HloInstruction* operand = operand_constraint.operand(); const HloInstruction* user = operand_constraint.instruction(); - if (!ShapeUtil::IsArray(operand->shape())) { + if (!operand->shape().IsArray()) { return Status::OK(); } - if (instruction_can_change_layout_func_(user) && - !ShapeUtil::IsArray(user->shape())) { + if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) { return Status::OK(); } @@ -1273,7 +1284,7 @@ Status LayoutAssignment::PropagateOperandConstraint( return Status::OK(); } - int64 operand_rank = ShapeUtil::Rank(operand->shape()); + int64 operand_rank = operand->shape().rank(); if (operand_rank <= 1) { return Status::OK(); } @@ -1288,7 +1299,7 @@ Status LayoutAssignment::PropagateOperandConstraint( continue; } const HloInstruction* sibling = user->operand(operand_no); - const int64 sibling_rank = ShapeUtil::Rank(sibling->shape()); + const int64 sibling_rank = sibling->shape().rank(); if (sibling_rank <= 1) { continue; } @@ -1317,16 +1328,16 @@ Status LayoutAssignment::PropagateOperandConstraint( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( user->shape(), [&](const Shape& subshape, const ShapeIndex& 
shape_index) { - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { return Status::OK(); } - if (ShapeUtil::Rank(subshape) <= 1) { + if (subshape.rank() <= 1) { return Status::OK(); } // Assign the right layout to input fusion of higher rank reduce // operations. - if (ShapeUtil::Rank(subshape) != ShapeUtil::Rank(operand->shape())) { + if (subshape.rank() != operand->shape().rank()) { return Status::OK(); } // TODO(b/67641796): Are there cases except fusion that use this code @@ -1354,10 +1365,10 @@ Status LayoutAssignment::PropagateOperandConstraint( } TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { return Status::OK(); } - if (ShapeUtil::Rank(subshape) <= 1) { + if (subshape.rank() <= 1) { return Status::OK(); } TF_ASSIGN_OR_RETURN( @@ -1373,7 +1384,7 @@ Status LayoutAssignment::PropagateOperandConstraint( TF_RETURN_IF_ERROR(constraints->SetBufferLayout( *layout, *buffer, /*mandatory=*/user->opcode() == HloOpcode::kReduce, - /*dfs=*/false)); + /*dfs=*/InstructionShouldPropagateDepthFirst(*user))); } } return Status::OK(); @@ -1401,8 +1412,8 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( } if (!instruction_can_change_layout_func_(instruction)) { // Copy the layout to the operand. 
- if (buffer.IsArray() && ShapeUtil::IsArray(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == + if (buffer.IsArray() && operand->shape().IsArray() && + operand->shape().rank() == LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) { TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( buffer_constraint.layout(), instruction, operand_no, @@ -1410,7 +1421,7 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( } } else { if (!buffer.IsTopLevel() || - !ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + !instruction->operand(operand_no)->shape().IsArray()) { continue; // Don't touch buffers that are internal to a tuple. } VLOG(6) << "Propagating constraint to operand " << operand_no << " of " @@ -1423,11 +1434,9 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(), instruction, operand_no); if (operand_layout != nullptr) { - // Do not propagate operand constraints of transposes and reshapes, it - // tends to create really bad layouts. TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( *operand_layout, instruction, operand_no, /*mandatory=*/false, - /*dfs=*/false)); + /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction))); } } else { VLOG(6) << "Operand already has a constraint " @@ -1497,7 +1506,7 @@ StatusOr InferArrayLayout( // This function should only be called for array shapes which don't yet have // layouts. const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index); - TF_RET_CHECK(ShapeUtil::IsArray(subshape)); + TF_RET_CHECK(subshape.IsArray()); TF_RET_CHECK(!subshape.has_layout()); // The instruction should not define the buffer at this index. @@ -1576,8 +1585,9 @@ Status SetFusionLayouts(HloInstruction* fusion) { fused_instruction->mutable_shape())); } else if (fused_instruction->opcode() == HloOpcode::kInfeed) { // Nop; leave the infeed layout alone. 
- } else { + } else if (fusion->fusion_kind() != HloInstruction::FusionKind::kCustom) { // Other instructions don't have layouts inside of fusion nodes. + // But do not clear layouts for other instructions in custom fusion nodes. LayoutUtil::ClearLayout(fused_instruction->mutable_shape()); } } @@ -1615,7 +1625,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, for (const LogicalBuffer* buffer : constraints.points_to_analysis().GetBuffersDefinedByInstruction( instruction)) { - if (!ShapeUtil::IsArray(buffer->shape())) { + if (!buffer->shape().IsArray()) { continue; } @@ -1639,7 +1649,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( instruction->mutable_shape(), [instruction, &constraints](Shape* subshape, const ShapeIndex& index) { - if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) { + if (subshape->has_layout() || !subshape->IsArray()) { return Status::OK(); } // Set Layout of subshape to match layout of LogicalBuffer which @@ -2012,7 +2022,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kConditional: case HloOpcode::kConvert: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: @@ -2048,6 +2058,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kRemainder: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRsqrt: case HloOpcode::kScatter: case HloOpcode::kSelect: case HloOpcode::kSelectAndScatter: @@ -2058,8 +2069,10 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kSin: case HloOpcode::kSlice: case HloOpcode::kSort: + case HloOpcode::kSqrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: + case HloOpcode::kTriangularSolve: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: return false; @@ -2085,6 +2098,7 
@@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReduce: + case HloOpcode::kReplicaId: case HloOpcode::kReshape: case HloOpcode::kRng: case HloOpcode::kSend: @@ -2100,8 +2114,8 @@ bool LayoutAssignment::InstructionCanChangeLayout( /* static */ bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { - if (ShapeUtil::IsArray(shape)) { - return ShapeUtil::Rank(shape) <= 1; + if (shape.IsArray()) { + return shape.rank() <= 1; } return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) { return IsAtMostRank1(subshape); @@ -2123,7 +2137,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kCopy && - added_copies_.count(instruction) > 0) { + added_copies_.contains(instruction)) { VLOG(5) << "Removing added copy: " << instruction->ToString(); TF_RETURN_IF_ERROR( instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 3b081de3c7826c3c11a7d87d542835d0ecce1b7e..5701cb5b025e563247d46d0d24f81a5f886fc23b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -243,7 +243,7 @@ class ChannelLayoutConstraints { // Returns true if channel_id has a layout constraint. bool IsChannelConstrained(int64 channel_id) const { - return constraints_.count(channel_id) > 0; + return constraints_.contains(channel_id); } // Given `shape`, apply the layout for `channel_id`. 
`channel_id` must already @@ -276,7 +276,7 @@ class ChannelLayoutConstraints { } private: - std::unordered_map constraints_; + absl::flat_hash_map constraints_; }; // HLO pass which assigns layouts to all instructions in the HLO module while diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 5c661bfacb08fe27f3cbdc1fb9db083315166008..c8cf3c47d380012fdb0206c0d20d67e6a13017ae 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -528,8 +528,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - if (ShapeUtil::Rank(instruction->shape()) != - ShapeUtil::Rank(operand->shape())) { + if (instruction->shape().rank() != operand->shape().rank()) { continue; } TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( @@ -847,12 +846,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - token = token[] after-all() - recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + token0 = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1} recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, sharding={maximal device=1} ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 - send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, + send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1, sharding={maximal device=0} send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} } @@ -894,11 +893,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { ENTRY entry_computation { param = 
(f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - ar.0 = f32[2,2] cross-replica-sum(gte), + ar.0 = f32[2,2] all-reduce(gte), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} - const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) - ROOT ar.1 = f32[2,2] cross-replica-sum(const), + const = f32[2,2] constant({{0,1},{2,3}}) + ROOT ar.1 = f32[2,2] all-reduce(const), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; @@ -961,8 +960,9 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { par0 = f32[3,4]{1,0} parameter(0) par1 = f32[4,5]{0,1} parameter(1) - par2 = s32[2] parameter(2) - dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + par2 = s32[] parameter(2) + par3 = s32[] parameter(3) + dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4} ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) } )"; @@ -983,7 +983,7 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { m::Parameter(), m::DynamicSlice( m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), - m::Parameter(2))))); + m::Parameter(2), m::Parameter(3))))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 728a66b388f0f9af480ff88b5e96990a26e36af5..c5d59fb28e02ce229967fb3856012d608fb83c5d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -39,7 +39,6 @@ cc_library( "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], @@ -169,6 +168,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", 
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 643ecd0fbaa546c551097b29e74ccd49418e1466..ce3d922ca7a9bdea3a520959a8b8d284bc3e0d64 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -81,9 +81,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, if (hlo.opcode() == HloOpcode::kParameter) { const std::vector& parameter_instructions = module_.entry_computation()->parameter_instructions(); - if (std::find(parameter_instructions.begin(), - parameter_instructions.end(), - &hlo) != parameter_instructions.end()) { + if (absl::c_linear_search(parameter_instructions, &hlo)) { array->MarkInvariantOverWholeProgram(context_); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 2b46b3c3964b15548dbacc8b0ada0047a0fa85b6..12e2f449e23ac2511aac576fed893f5a9ef510c0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -76,15 +76,12 @@ class AliasAnalysis { // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - absl::flat_hash_map + absl::flat_hash_map alias_scope_metadata_; // A map from a buffer slice to metadata corresponding to its noalias // metadata. 
- absl::flat_hash_map - noalias_metadata_; + absl::flat_hash_map noalias_metadata_; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index bdce4a171b8a58f617f1d56e6cf6db5354846703..1ea5a42b0b398818b0946eaa9e214100007bada4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,14 +41,26 @@ static const HloInstruction& InstrForConstantBufferAllocation( return *const_instr; } -string ConstantBufferAllocationToGlobalName( - const BufferAllocation& allocation) { - string instr_name = InstrForConstantBufferAllocation(allocation).name(); +string SanitizeConstantName(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kConstant); + string instr_name = instr.name(); for (char& c : instr_name) { - if (c == '.') { + // Having a hyphen or a dot in a global variable name can crash the LLVM PTX + // backend. + if (c == '.' || c == '-') { c = '_'; } } + return instr_name; +} + +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + const HloInstruction& instr = InstrForConstantBufferAllocation(allocation); + string instr_name = instr.name(); + // Check that names are sanitized and stored in the HLO instructions + // before constant buffer allocation. 
+ DCHECK_EQ(instr_name, SanitizeConstantName(instr)); return absl::StrCat("buffer_for_", instr_name); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h index bfb6eecb87f6a1b756b3a8da3377f608dd7f0be7..03e98a66900095889292cbff9d9924a9abe83ab0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h @@ -20,6 +20,10 @@ limitations under the License. namespace xla { namespace llvm_ir { +// Sanitizes the HLO constant instruction name so that it can be used for the +// name of the corresponding constant buffer. In particular, it replaces . and +// - with _. +string SanitizeConstantName(const HloInstruction& instr); // In XLA:GPU we map constant buffer allocations to globals in the generated // LLVM IR. This function gives us the name of the global variable a constant // buffer is mapped to. Not used on XLA:CPU. diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 4d7f36d9f8b565a819edf0631efc5c7a58c4f87f..3acceccfa556103c15fe229c41e96e618ac59c80 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -36,19 +36,20 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, // EmitFusedDynamicUpdateSliceInPlace. // // Emits a sequential loop if launch_dimensions is null. 
+using IndexGenerator = std::function(int64)>; + static Status EmitDynamicUpdateSliceInPlaceImpl( - const Shape& update_shape, const ElementGenerator& start_indices_generator, + const Shape& update_shape, const IndexGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, absl::string_view name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. - const int64 rank = ShapeUtil::Rank(output_shape); + const int64 rank = output_shape.rank(); IrArray::Index start_index(b->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { - IrArray::Index dim_index({b->getInt64(i)}); - TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(i)); llvm::Value* output_dim_size = llvm::ConstantInt::get( start_index[i]->getType(), output_shape.dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( @@ -112,8 +113,9 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, Shape output_shape = output_array.GetShape(); Shape update_shape = update_array.GetShape(); - ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { - return start_indices_array.EmitReadArrayElement(index, b); + IndexGenerator start_indices_generator = [&](int64 index) { + return operand_arrays[2 + index].EmitReadArrayElement( + IrArray::Index(b->getInt64Ty()), b); }; ElementGenerator update_array_generator = [&](const IrArray::Index& index) { return update_array.EmitReadArrayElement(index, b); @@ -165,9 +167,12 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( elemental_emitter); TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); - ElementGenerator start_indices_generator = - 
fused_emitter.GetGenerator(start_indices); + IndexGenerator start_indices_generator = [&](int64 index) { + ElementGenerator element_generator = + fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); + return element_generator(IrArray::Index(b->getInt64Ty())); + }; bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); return EmitDynamicUpdateSliceInPlaceImpl( update_shape, start_indices_generator, is_signed, update_array_generator, diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 38f2b5da23a7b92e4547dceaba011ce654977da3..e440f05e2b2f0d4a2a4c7b326b4881183de4d235 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -35,7 +35,7 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { - if (generated_value_cache_[hlo].count(index.multidim()) > 0) { + if (generated_value_cache_[hlo].contains(index.multidim())) { llvm::Value* generated_value = generated_value_cache_[hlo][index.multidim()]; llvm::BasicBlock* generated_value_bb = nullptr; @@ -115,7 +115,7 @@ Status FusedIrEmitter::HandleGetTupleElement( /*alignment=*/1, tuple_ptr, b_, module_); }; - if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { + if (!get_tuple_element->shape().IsTuple()) { indexed_generators_[get_tuple_element] = [=](const IrArray::Index& index) -> StatusOr { // TODO(b/34080002) Add aliasing information to tuple element IrArray. 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 1b9c61f6700e2a1309b21e499f4a9e2439ed3702..e6d52a580c04a920d3f0e8ed6f39c1cae587cf1b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" @@ -134,8 +135,9 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges - std::unordered_map, llvm::Value*>> + absl::flat_hash_map< + const HloInstruction*, + absl::flat_hash_map, llvm::Value*>> generated_value_cache_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 67f7423121177e2ca1e3384341dad2644c8f5e34..8ee07ae8331e986f9d271be5e39065f0d87853b1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -61,7 +61,7 @@ void IrArray::Index::Delinearize(std::vector* multidim, IrArray::Index::Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b) - : multidim_(ShapeUtil::Rank(shape)), + : multidim_(shape.rank()), linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { @@ -104,8 +104,8 @@ IrArray::Index::Index(absl::Span multidim, CHECK(LayoutUtil::HasLayout(shape)); } -IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) - : base_ptr_(base_ptr), shape_(&shape) { +IrArray::IrArray(llvm::Value* base_ptr, Shape shape) + : base_ptr_(base_ptr), shape_(std::move(shape)) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); CHECK(base_ptr_->getType()->isPointerTy()); int depth = 0; @@ -117,10 +117,10 @@ 
IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) ++depth; } - if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) { + if (!shape_->IsArray() || ShapeUtil::IsScalar(*shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { - DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); + DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString(); } } @@ -137,12 +137,12 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( const Shape& output_shape, const Shape& input_shape, llvm::IRBuilder<>* builder) const { const auto& target_index = *this; - CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape)); + CHECK_EQ(target_index.size(), output_shape.rank()); std::vector> common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( - ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_)); + input_shape.rank(), llvm::UndefValue::get(index_type_)); // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { @@ -257,7 +257,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { - int64 rank = ShapeUtil::Rank(operand_shape); + int64 rank = operand_shape.rank(); std::vector source_index(rank); for (int64 i = 0; i < rank; ++i) { source_index[i] = multidim_[dimension_mapping[i]]; @@ -271,7 +271,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( // The other dimensions can be masked out with a div and a mod operation. std::vector logical_to_physical = LayoutUtil::MakeLogicalToPhysical(shape.layout()); - int64 output_rank = ShapeUtil::Rank(shape); + int64 output_rank = shape.rank(); // The minimum physical dimension that is broadcasted. 
int64 min_broadcasted_dimension = output_rank; // The maximum physical dimension that is broadcasted. @@ -348,7 +348,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, // over higher-rank arrays. return base_ptr_; } - CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); + CHECK_EQ(index.size(), shape_->rank()); if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index d6d84994ee147f4b8c1a333b0eaccdf6e0a2219b..b706ebd311cbb706e7e4698b93319e37e664d10a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -130,6 +130,11 @@ class IrArray { CHECK_LE(index, size()); mutable_multidim().insert(mutable_multidim().begin() + index, value); } + void InsertAt(int64 index, int64 count, llvm::Value* value) { + CHECK_LE(index, size()); + mutable_multidim().insert(mutable_multidim().begin() + index, count, + value); + } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; @@ -189,6 +194,8 @@ class IrArray { return llvm::ConstantInt::get(index_type_, c); } + void ClearLinearIndex() { linear_ = nullptr; } + private: // Changing the multi-dimensional index invalidates the linear index. std::vector& mutable_multidim() { @@ -220,11 +227,11 @@ class IrArray { }; // Default constructor. Constructs an IrArray in a null status. - IrArray() : base_ptr_(nullptr), shape_(nullptr) {} + IrArray() : base_ptr_(nullptr) {} // Construct an IrArray with the given base pointer and shape. base_ptr is a // pointer type pointing to the first element(lowest address) of the array. - IrArray(llvm::Value* base_ptr, const Shape& shape); + IrArray(llvm::Value* base_ptr, Shape shape); // Default implementations of copying and moving. 
IrArray(IrArray&& other) = default; @@ -236,7 +243,6 @@ class IrArray { llvm::Type* GetElementLlvmType() const { return element_type_; } const Shape& GetShape() const { - CHECK(shape_ != nullptr); return *shape_; } @@ -331,7 +337,7 @@ class IrArray { llvm::Type* element_type_; // Shape of the XLA array. - const Shape* shape_; + absl::optional shape_; // The list of key/value pairs used when attaching metadata to emitted // loads/stores for this array. They keys are the metadata kinds and the diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h index abc06fb7b4245294df2dc20d25a22ac4fdaeb4cf..02c719502ee7b0a732ae74acec364f89d51ae0c1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -254,6 +254,11 @@ class IrBuilderMixin { return mixin_builder()->CreateFCmpOLT(std::forward(args)...); } + template + llvm::Value* FCmpOLE(Args&&... args) { + return mixin_builder()->CreateFCmpOLE(std::forward(args)...); + } + template llvm::Value* FCmpONE(Args&&... args) { return mixin_builder()->CreateFCmpONE(std::forward(args)...); @@ -264,6 +269,11 @@ class IrBuilderMixin { return mixin_builder()->CreateFCmpUNE(std::forward(args)...); } + template + llvm::Value* FCmpUNO(Args&&... args) { + return mixin_builder()->CreateFCmpUNO(std::forward(args)...); + } + template llvm::Value* FDiv(Args&&... args) { return mixin_builder()->CreateFDiv(std::forward(args)...); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index bd0139f85b6a5c5dc23dad962263038451921e65..5eeb29c478a371dae83251771f2dc4844672d3e9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -18,28 +18,29 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return If(b_->CreateICmpSLT(start, end), [&]() -> Status { + return IfWithStatus(b_->CreateICmpSLT(start, end), [&]() -> Status { TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); - return For(name, b_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { return for_body_generator(iv, false); }); + return ForWithStatus( + name, b_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator) { if (peel_first_iteration) { - return For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) -> Status { - return for_body_generator(indvar, - b_->getInt1(is_first_iteration)); - }); + return ForWithStatus( + name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator(indvar, b_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, b_, @@ -55,7 +56,7 @@ Status KernelSupportLibrary::For( } } -Status KernelSupportLibrary::If( +Status KernelSupportLibrary::IfWithStatus( absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 43fec311f150d6054f6ad24f99db332f90ff94a3..612b839cfa15711061e1ae53358a72d5220e1801 
100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -48,41 +48,42 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator); - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { CHECK_EQ(Status::OK(), - For(name, start, end, step, + ForWithStatus( + name, start, end, step, [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); return Status::OK(); })); } - Status For(absl::string_view name, int64 start, int64 end, int64 step, - const std::function& - for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + Status ForWithStatus( + absl::string_view name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -99,19 +100,19 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // 
`for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& - for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, llvm::Value* step, - bool peel_first_iteration, - const std::function& - for_body_generator) { - TF_CHECK_OK(For( + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(ForWithStatus( name, start, end, step, peel_first_iteration, [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); @@ -119,80 +120,81 @@ class KernelSupportLibrary { })); } - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - return For(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - peel_first_iteration, for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + return ForWithStatus( + name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - ForReturnVoid(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - 
peel_first_iteration, for_body_generator); + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + For(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { - return for_body_generator(indvar); - }); + For(name, start, end, step, + /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, + llvm::ConstantInt::get(start->getType(), step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, 
llvm::Value* end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, - llvm::ConstantInt::get(start->getType(), step), - for_body_generator); + For(name, start, end, llvm::ConstantInt::get(start->getType(), step), + for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -201,38 +203,43 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }); + Status IfWithStatus( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }); - Status If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }) { - return If("", condition, true_block_generator, false_block_generator); + Status IfWithStatus( + llvm::Value* condition, + const std::function& 
true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }) { + return IfWithStatus("", condition, true_block_generator, + false_block_generator); } - void IfReturnVoid(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - IfReturnVoid("", condition, true_block_generator, false_block_generator); + void If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + If("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - TF_CHECK_OK(If(name, condition, - [&]() { - true_block_generator(); - return Status::OK(); - }, - [&]() { - false_block_generator(); - return Status::OK(); - })); + void If( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + TF_CHECK_OK(IfWithStatus( + name, condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); } using ArgumentVector = absl::Span; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index c26711e526c9b89cdedcb6aed9f93d41dd25dc83..cd8dd72cd775d5e0b52f96a2326367da0775e7eb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -120,10 +120,11 @@ KernelMappingScheme::KernelMappingScheme( absl::Span req_block_sizes, int64 num_threads_y, int64 num_threads_x, llvm::IRBuilder<>* b) : b_(b), - dims_in_elems_(dims_in_elems), + dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), tile_sizes_{1, tile_size_y, 
tile_size_x}, num_threads_x_(num_threads_x), - num_threads_y_(num_threads_y) { + num_threads_y_(num_threads_y), + dilated_x_(true) { DCHECK_EQ(dims_in_elems_.size(), 3); DCHECK_EQ(req_block_sizes.size(), 3); @@ -170,14 +171,16 @@ IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( const IrArray::Index& block_index) { - IrArray::Index tile_index = block_index; + DCHECK_EQ(block_index.size(), block_sizes_.size()); + std::vector multidim; + multidim.reserve(block_sizes_.size()); for (int i = 0; i < block_sizes_.size(); ++i) { - tile_index[i] = b_->CreateMul( + multidim.push_back(b_->CreateMul( block_index[i], llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), - "block_origin." + std::to_string(i)); + "block_origin." + std::to_string(i))); } - return tile_index; + return IrArray::Index(multidim, block_index[0]->getType()); } IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( @@ -217,14 +220,14 @@ KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { // defined by (num_thread_y, num_thread_x) from thread_id. 
llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw); llvm::Value* thread_id_int = b_->CreateIntCast(thread_id_raw, index_ty, /*isSigned=*/true, "thread.id.x"); llvm::Value* num_thread_x = llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); - llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); - llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x, "thread.x"); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x, "thread.y"); return std::make_tuple(y, x); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 06002d57b0d7daa07f903feebe67a60a083c0e7c..f802cc27d519e621262f328903697373aa8c284c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -90,15 +90,16 @@ class KernelMappingScheme { enum { DimZ = 0, DimY, DimX, DimTot }; public: + KernelMappingScheme() {} // dims_in_elems: the normalized tensor dimensions. // req_block_sizes: the requested block size in number of tiles for each // dimension. The actual block size is set to min(req_block_size, // dims_in_number_of_blocks). 
- explicit KernelMappingScheme(absl::Span dims_in_elems, - int64 tile_size_y, int64 tile_size_x, - absl::Span req_block_sizes, - int64 num_threads_y, int64 num_threads_x, - llvm::IRBuilder<>* b); + KernelMappingScheme(absl::Span dims_in_elems, int64 tile_size_y, + int64 tile_size_x, + absl::Span req_block_sizes, + int64 num_threads_y, int64 num_threads_x, + llvm::IRBuilder<>* b); absl::Span GetDimensionsInElements() const { return dims_in_elems_; @@ -116,7 +117,10 @@ class KernelMappingScheme { int64 GetNumberOfTilesInOneBlock() const { return absl::c_accumulate(block_sizes_, 1, std::multiplies()); } - + int64 GetNumberOfTilesInOneBlockForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return block_sizes_[d]; + } int64 GetNumberOfBlocks() const { return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies()); } @@ -133,15 +137,29 @@ class KernelMappingScheme { } absl::Span GetBlockSizes() const { return block_sizes_; } + int64 GetTileBlockSizeForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return dims_in_blocks_[d]; + } int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } - int64 GetThreadsPerTile() const { + int64 GetThreadsPerBlock() const { return GetNumberOfThreadsForDimensionX() * GetNumberOfThreadsForDimensionY(); } + bool DilatedX() const { return dilated_x_; } + void SetDilatedX(bool v) { + dilated_x_ = v; + if (!dilated_x_) { + // dilated_x_=false is for the purpose of vectorization, which requires + // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_. + CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0); + } + } + IrArray::Index EmitBlockIndex(llvm::Type* index_ty); // Returns the index for the first tile in the block with the given block // index. @@ -163,7 +181,7 @@ class KernelMappingScheme { private: llvm::IRBuilder<>* b_; // The number of elements in each dimension. 
- absl::Span dims_in_elems_; + std::vector dims_in_elems_; // The number of elements for each dimension of a tile. std::vector tile_sizes_; @@ -181,6 +199,13 @@ class KernelMappingScheme { int64 num_threads_x_; // Number of threads used to process elements in the Y direction of a tile. int64 num_threads_y_; + + // When num_threads_x threads process a total of tile_size_x elements in the + // X dimension of a tile, each threads process n=tile_size_x/num_threads_x + // elements. When dilated_x=false, the n elements processed by a thread are + // contiguous. On the other hand, when dilated_x=true the n elements are + // dilated by a factor of num_threads_x. + bool dilated_x_; }; // A class to represent information for tiled parameters to support IR emission diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 219a9f221fbd116cdfbaf17985e21d82aefd079d..3a35405a2da0af386e01bb48bed56ad194048543 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -235,7 +234,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, absl::string_view suffix) { - std::vector dimensions(ShapeUtil::Rank(shape)); + std::vector dimensions(shape.rank()); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ceea24685af566e02340664f0a40c398c62b5ab0..807296329c07b8e4ac630486a1e1f59e4fdfa009 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -188,7 +188,16 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } return cplx_t; } - // A Tuple contains an array of pointers. Use i8*. + case C128: { + auto cplx_t = module->getTypeByName("complex128"); + if (cplx_t == nullptr) { + return llvm::StructType::create( + {llvm::Type::getDoubleTy(module->getContext()), + llvm::Type::getDoubleTy(module->getContext())}, + "complex128", /*isPacked=*/true); + } + return cplx_t; + } // A Tuple contains an array of pointers. Use i8*. case TUPLE: // An Opaque is like a void*, use i8*. case OPAQUE: @@ -219,10 +228,10 @@ int GetSizeInBits(llvm::Type* type) { llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { // A tuple buffer is an array of pointers. 
result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { result_type = llvm::ArrayType::get(result_type, shape.dimensions(dimension)); @@ -621,6 +630,10 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type, function->setCallingConv(llvm::CallingConv::C); function->addFnAttr("no-frame-pointer-elim", "false"); + // Generate unwind information so that GDB can crawl through the stack frames + // created by the JIT compiled code. + function->setHasUWTable(); + if (enable_fast_math) { function->addFnAttr("unsafe-fp-math", "true"); function->addFnAttr("no-infs-fp-math", "true"); diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 0dc120e0b0df47f261435f490a8459b49d989b53..a689881e65ec3a7ddf606c36bdd64b749cfe358e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index e22c2173c271fc9571be1ddb0759d2b31562dc98..d71addec9b7317dfe16e9d7e5380c3cfda0b8c06 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -45,13 +46,14 @@ namespace llvm_ir { namespace { // Adds the inner comparison loop body where we compare elements. -void EmitCompareLoopBody( - int64 iteration_bound, PrimitiveType key_type, int64 num_values, - int64 iota_values_parameter_index, llvm::Value* element_pair_index, +Status EmitCompareLoopBody( + int64 iteration_bound, int64 num_values, llvm::Value* element_pair_index, int64 xor_mask, llvm::Type* index_type, - std::function read_element, + std::function + element_address, std::function write_element, + const EmitCallToNestedComputationCallback& emit_compare_callback, llvm::IRBuilder<>* b, bool needs_bounds_checks = true) { auto index_typed_constant = [&](int64 value) { return llvm::ConstantInt::get(index_type, value); @@ -108,74 +110,44 @@ void EmitCompareLoopBody( // if (is_smaller_index && index_is_inbounds) KernelSupportLibrary ksl(b); - ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { - auto key1 = read_element(0, current_keys_index); - auto key2 = read_element(0, compare_keys_index); - auto compare_key1 = key1; - auto compare_key2 = key2; - bool is_signed_comparison = true; - if (primitive_util::IsFloatingPointType(key_type)) { - // We would like a total order of floating point numbers so that the - // sort has a predictable behavior in the presence of NaNs. Rather - // than using floating point comparison, we use the following trick: - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? 0x7FFFFFFF - x : x; - // then y is ordered as an int32 such that finite values have the - // obvious order, -0 is ordered before 0, and -NaN and NaN appear at - // the beginning and end of the ordering. 
- auto k = b->getInt(llvm::APInt::getSignedMaxValue( - key1->getType()->getPrimitiveSizeInBits())); - auto comparison_type = k->getType(); - auto zero = llvm::ConstantInt::get(comparison_type, 0); - auto maybe_flip = [&](llvm::Value* v) { - return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), - b->CreateSub(k, v), v); - }; - compare_key1 = b->CreateBitCast(key1, comparison_type); - compare_key2 = b->CreateBitCast(key2, comparison_type); - compare_key1 = maybe_flip(compare_key1); - compare_key2 = maybe_flip(compare_key2); - } else if (!primitive_util::IsSignedIntegralType(key_type)) { - is_signed_comparison = false; - } - // If key2 < key1 - auto is_smaller_than = - b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1); - if (iota_values_parameter_index >= 0) { - auto keys_equal = b->CreateICmpEQ(compare_key1, compare_key2); - auto key_index1 = - read_element(iota_values_parameter_index, current_keys_index); - auto key_index2 = - read_element(iota_values_parameter_index, compare_keys_index); - auto index_is_smaller_than = - b->CreateICmp(llvm::ICmpInst::ICMP_ULT, key_index2, key_index1); - is_smaller_than = b->CreateOr( - is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); + return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() { + std::vector values_to_compare; + for (int i = 0; i < num_values; ++i) { + values_to_compare.push_back(element_address(i, compare_keys_index)); + values_to_compare.push_back(element_address(i, current_keys_index)); } - ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { - // Swap key1 with key2. - write_element(0, current_keys_index, key2); - write_element(0, compare_keys_index, key1); - for (int64 i = 1; i <= num_values; ++i) { - // Also swap the values. 
- auto value1 = read_element(i, current_keys_index); - auto value2 = read_element(i, compare_keys_index); - write_element(i, current_keys_index, value2); - write_element(i, compare_keys_index, value1); + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(PRED, module), "compare_return_buffer", + b); + TF_RETURN_IF_ERROR( + emit_compare_callback(values_to_compare, compare_return_buffer)); + llvm::Value* result = b->CreateLoad(compare_return_buffer); + + // Check if the 'compare' function returns true. + llvm::Value* is_smaller_than = + b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0), + "boolean_predicate"); + ksl.If("is_smaller_than", is_smaller_than, [&]() { + for (int64 i = 0; i < num_values; ++i) { + // Swap the values. + auto value1 = b->CreateLoad(values_to_compare[i * 2]); + auto value2 = b->CreateLoad(values_to_compare[i * 2 + 1]); + write_element(i, current_keys_index, value1); + write_element(i, compare_keys_index, value2); } }); + return Status::OK(); }); } -void EmitTiledCompareLoop( +Status EmitTiledCompareLoop( const IrArray::Index& tiled_keys_index, int64 dimension_to_sort, - int64 dimension_to_sort_bound, PrimitiveType keys_type, - absl::Span xor_masks, const std::vector& params, - const std::vector& param_shmem_buffers, - int64 iota_values_parameter_index, int64 tile_size, llvm::IRBuilder<>* b) { + int64 dimension_to_sort_bound, absl::Span xor_masks, + const std::vector& params, + const std::vector& param_shmem_buffers, int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback, + llvm::IRBuilder<>* b) { KernelSupportLibrary ksl(b); llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); @@ -192,7 +164,7 @@ void EmitTiledCompareLoop( b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); // We want to copy two adjacent 
elements. We first check whether the // first index position is within bounds. - ksl.IfReturnVoid( + ksl.If( "smaller_keys_index", b->CreateICmpSLT(current_keys_index, tiled_keys_index.GetConstantWithIndexType( @@ -200,18 +172,17 @@ void EmitTiledCompareLoop( [&]() { auto cache_index = b->CreateShl(thread_id, value_one); read_or_write(cache_index, current_keys_index); - // Increment to go the next index position. + // Increment to go to the next index position. current_keys_index = b->CreateAdd(current_keys_index, value_one); // Here we check whether the next index position is within bounds. - ksl.IfReturnVoid( - "inner_smaller_keys_index", - b->CreateICmpSLT(current_keys_index, - tiled_keys_index.GetConstantWithIndexType( - dimension_to_sort_bound)), - [&]() { - cache_index = b->CreateAdd(cache_index, value_one); - read_or_write(cache_index, current_keys_index); - }); + ksl.If("inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); }); }; @@ -231,10 +202,18 @@ void EmitTiledCompareLoop( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); // Now emit the bodies of the comparison loops. - auto read_element = [&](int64 operand, llvm::Value* index) { - return b->CreateLoad( + auto element_address = [&](int64 operand, llvm::Value* index) { + auto shared_memory_address = b->CreateGEP(param_shmem_buffers[operand], - {tiled_keys_index.GetConstantWithIndexType(0), index})); + {tiled_keys_index.GetConstantWithIndexType(0), index}); + auto ptr_type = shared_memory_address->getType(); + // We need a generic pointer with address space 0 instead of a pointer to + // shared memory (address space 3) so that we can pass it to the comparison + // computation. 
+ return b->CreateAddrSpaceCast( + shared_memory_address, + llvm::PointerType::get(ptr_type->getPointerElementType(), + /*AddressSpace=*/0)); }; auto write_element = [&](int64 operand, llvm::Value* index, llvm::Value* value) { @@ -253,7 +232,7 @@ void EmitTiledCompareLoop( if (dimension_to_sort_bound % tile_size) { // Otherwise we need a bounds check for the last tile. The last tile has // size 'dimension_to_sort_bound' % 'tile_size'. - ksl.IfReturnVoid( + TF_RETURN_IF_ERROR(ksl.IfWithStatus( "is_last_tile", b->CreateICmpUGE( b->CreateMul(tiled_keys_index[dimension_to_sort], @@ -261,24 +240,24 @@ void EmitTiledCompareLoop( tiled_keys_index.GetConstantWithIndexType( RoundDownToNearest(dimension_to_sort_bound, tile_size))), [&]() { - EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, - params.size() - 1, iota_values_parameter_index, - element_pair_index, xor_mask, - tiled_keys_index.GetType(), read_element, - write_element, b); + return EmitCompareLoopBody( + dimension_to_sort_bound % tile_size, params.size(), + element_pair_index, xor_mask, tiled_keys_index.GetType(), + element_address, write_element, emit_compare_callback, b); }, [&]() { - EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, - iota_values_parameter_index, element_pair_index, - xor_mask, tiled_keys_index.GetType(), - read_element, write_element, b, - /*needs_bounds_checks=*/false); - }); + return EmitCompareLoopBody( + tile_size, params.size(), element_pair_index, xor_mask, + tiled_keys_index.GetType(), element_address, write_element, + emit_compare_callback, b, + /*needs_bounds_checks=*/false); + })); } else { - EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, - iota_values_parameter_index, element_pair_index, - xor_mask, tiled_keys_index.GetType(), read_element, - write_element, b, /*needs_bounds_checks=*/false); + TF_RETURN_IF_ERROR(EmitCompareLoopBody( + tile_size, params.size(), element_pair_index, xor_mask, + tiled_keys_index.GetType(), element_address, 
write_element, + emit_compare_callback, b, + /*needs_bounds_checks=*/false)); } // Wait until all comparisons have happened. llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); @@ -302,17 +281,16 @@ void EmitTiledCompareLoop( // same location in shared memory because we have exactly tile_size / 2 many // threads, and the linear index calculated by ParallelLoopEmitter uses // linear_index = blockIdx.x * blockDim.x + threadIdx.x; + return Status::OK(); } } // namespace -Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const std::vector& values_arrays, - int64 iota_values_parameter_index, - absl::string_view name, - absl::Span xor_masks, llvm::IRBuilder<>* b, - const gpu::LaunchDimensions& launch_dimensions, - int64 num_iterations_in_sort_dim, - const int64 tile_size) { +Status EmitSortInPlace( + int64 dimension_to_sort, const std::vector& values_arrays, + absl::string_view name, absl::Span xor_masks, + llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, const int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback) { // Iterate through the keys shape in physical order, but skip the dimension to // sort and make it the innermost loop which is the loop where the comparisons // happen. In the dimension to sort, if we use tiling, we iterate through it @@ -322,8 +300,8 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, // within those 64 elements and are therefore independent of the other // comparisons). 
- const Shape& keys_shape = keys_array.GetShape(); - int64 rank = ShapeUtil::Rank(keys_shape); + const Shape& keys_shape = values_arrays[0].GetShape(); + int64 rank = keys_shape.rank(); int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); std::vector dimensions_in_iteration_order(rank); std::vector iteration_order_to_logical_order(rank); @@ -339,18 +317,16 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(), dimensions_in_iteration_order); - std::vector params(1, keys_array); - params.insert(params.end(), values_arrays.begin(), values_arrays.end()); // Allocate shared memory for the tiled compare loop. - std::vector param_shmem_buffers(params.size(), nullptr); + std::vector param_shmem_buffers(values_arrays.size(), nullptr); if (xor_masks.size() > 1) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); - for (int64 i = 0; i < params.size(); ++i) { - llvm::Type* tile_type = - llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( - params[i].GetShape().element_type(), module), - tile_size); + for (int64 i = 0; i < values_arrays.size(); ++i) { + llvm::Type* tile_type = llvm::ArrayType::get( + llvm_ir::PrimitiveTypeToIrType( + values_arrays[i].GetShape().element_type(), module), + tile_size); param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile( module, tile_type, absl::StrCat(name, "_tile_param_", i)); } @@ -377,25 +353,24 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, keys_index[iteration_order_to_logical_order[i]] = tiles_index[i]; } if (xor_masks.size() > 1) { - EmitTiledCompareLoop(keys_index, dimension_to_sort, - dimension_to_sort_bound, keys_shape.element_type(), - xor_masks, params, param_shmem_buffers, - iota_values_parameter_index, tile_size, b); + TF_RETURN_IF_ERROR(EmitTiledCompareLoop( + keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks, + values_arrays, 
param_shmem_buffers, tile_size, emit_compare_callback, + b)); } else { - auto read_element = [&](int64 operand, llvm::Value* index) { + auto element_address = [&](int64 operand, llvm::Value* index) { keys_index[dimension_to_sort] = index; - return params[operand].EmitReadArrayElement(keys_index, b); + return values_arrays[operand].EmitArrayElementAddress(keys_index, b); }; auto write_element = [&](int64 operand, llvm::Value* index, llvm::Value* value) { keys_index[dimension_to_sort] = index; - params[operand].EmitWriteArrayElement(keys_index, value, b); + values_arrays[operand].EmitWriteArrayElement(keys_index, value, b); }; - EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), - values_arrays.size(), iota_values_parameter_index, - tiles_index[rank - 1], xor_masks[0], - tiles_index.GetType(), read_element, write_element, - b); + TF_RETURN_IF_ERROR(EmitCompareLoopBody( + dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1], + xor_masks[0], tiles_index.GetType(), element_address, write_element, + emit_compare_callback, b)); } return Status::OK(); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 685f9383acba416f51681270e4037d56abb4b6ea..b9341a34d1f2203db6e02c3df5d607174b6d0f74 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -28,19 +28,18 @@ limitations under the License. namespace xla { namespace llvm_ir { +using EmitCallToNestedComputationCallback = + std::function, llvm::Value*)>; // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' -// dimension of 'keys_array'. All other dimensions are kept as-is. This -// implements the inner loop of BitonicSort. It is assumed that 'xor_masks' -// contains only powers of 2, or values 2^k - 1 (k > 0). 
If -// 'iota_values_parameter_index' is >= 0, it points at a 'values_arrays' operand -// that is a iota and can be used to make the sorting stable. -Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const std::vector& values_arrays, - int64 iota_values_parameter_index, - absl::string_view name, - absl::Span xor_masks, llvm::IRBuilder<>* b, - const gpu::LaunchDimensions& launch_dimensions, - int64 num_iterations_in_sort_dim, int64 tile_size); +// dimension of each array in 'values_arrays'. All other dimensions are kept +// as-is. This implements the inner loop of BitonicSort. It is assumed that +// 'xor_masks' contains only powers of 2, or values 2^k - 1 (k > 0). +Status EmitSortInPlace( + int64 dimension_to_sort, const std::vector& values_arrays, + absl::string_view name, absl::Span xor_masks, + llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index a60643bc754f896d096b3ca4e1216e77d7e384c6..d8d2700e1934fd202d44a1dc60e71a99913d4537 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -93,7 +93,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr); // Mark the loaded pointer as dereferenceable if we know its shape. 
- if (!ShapeUtil::IsOpaque(target_shape)) { + if (!target_shape.IsOpaque()) { SetDereferenceableMetadataForLoad( src_buffer, ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout())); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 6c89700983363fec46c41b5430c6eab6b366a1b6..3470fe5b2c34bf832207ed546fad176319446f31 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -52,8 +52,10 @@ namespace xla { } BackendOptions backend_options; - backend_options.set_platform(platform).set_intra_op_parallelism_threads( - options.intra_op_parallelism_threads()); + backend_options.set_platform(platform) + .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()) + .set_allowed_devices(options.allowed_devices()); + TF_ASSIGN_OR_RETURN(std::unique_ptr backend, Backend::CreateBackend(backend_options)); @@ -108,6 +110,7 @@ ExecutionOptions CreateExecutionOptions( *execution_options.mutable_shape_with_output_layout() = result_shape.ToProto(); } + execution_options.set_num_replicas(build_options.num_replicas()); return execution_options; } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 9ccdd7d8d818b9fa3aa77cdd10d37ca18928b448..53d52d9a3d918fa6dee093668923fcfff963d084 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -198,7 +198,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { continue; } - if (in_list.count(instr) > 0) { + if (in_list.contains(instr)) { continue; } int64 profit = GetProfit(instr, fusion); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 
ac2f79674feceff436c0e9c65338967f498e4473..e55b83d17e90bc2ca0053a0421cf80ef6edd5bca 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -27,13 +29,13 @@ namespace { bool IsAllowed(char character) { auto c = static_cast(character); - return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; + return (absl::ascii_isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; } } // namespace NameUniquer::NameUniquer(const string& separator) { - CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + CHECK(absl::c_all_of(separator, IsAllowed)) << "separator should comprises allowed characters only"; separator_ = separator; } @@ -42,9 +44,10 @@ NameUniquer::NameUniquer(const string& separator) { if (name.empty()) { return ""; } + string result = name; char c = static_cast(result[0]); - if (!isalpha(c) && c != '_') { + if (!absl::ascii_isalpha(c) && c != '_') { result[0] = '_'; } for (int i = 1; i < result.length(); i++) { @@ -52,6 +55,13 @@ NameUniquer::NameUniquer(const string& separator) { result[i] = '_'; } } + + // HLO primitive type names (with the exception of 'tuple') are keywords in + // the HLO text representation and cannot be names, so append an underscore if + // the name is a primitive type. 
+ if (primitive_util::IsPrimitiveTypeName(result) && result != "tuple") { + result += "_"; + } return result; } diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 3e2592c6ac626143f1421e545a31d9be91e376bc..d0d04147e0c29c66cba447550c0a9c703f35573a 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -104,5 +104,21 @@ TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); } +TEST_F(NameUniquerTest, AvoidKeywords) { + NameUniquer uniquer("."); + + EXPECT_EQ("f32_", uniquer.GetUniqueName("f32")); + EXPECT_EQ("s64_", uniquer.GetUniqueName("s64")); + EXPECT_EQ("pred_", uniquer.GetUniqueName("pred")); + + // Though a primitive type, "tuple" is not a keyword. + EXPECT_EQ("tuple", uniquer.GetUniqueName("tuple")); + + // Keywords are not capitalized. + EXPECT_EQ("F32", uniquer.GetUniqueName("F32")); + EXPECT_EQ("S32", uniquer.GetUniqueName("S32")); + EXPECT_EQ("Pred", uniquer.GetUniqueName("Pred")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/op_expander_pass.cc b/tensorflow/compiler/xla/service/op_expander_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..02c9d4b387b112be39c204d35fe4fa1013ed064c --- /dev/null +++ b/tensorflow/compiler/xla/service/op_expander_pass.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +StatusOr OpExpanderPass::Run(HloModule* module) { + std::vector matching_instructions; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + absl::c_copy_if( + computation->instructions(), std::back_inserter(matching_instructions), + [&](HloInstruction* inst) { return InstructionMatchesPattern(inst); }); + } + + for (HloInstruction* inst : matching_instructions) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, + ExpandInstruction(inst)); + if (expanded_root == nullptr) { + continue; + } + TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); + } + + return !matching_instructions.empty(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/op_expander_pass.h b/tensorflow/compiler/xla/service/op_expander_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..276e3d70b8ecd8742e0b277698765063198fe872 --- /dev/null +++ b/tensorflow/compiler/xla/service/op_expander_pass.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass is an abstract superclass for passes that replace operations that +// match a pattern. It is intended to be subclassed, not used directly. +// +// This pass is useful for legalizing HLO instructions that a particular backend +// does not support into other HLO instructions. +class OpExpanderPass : public HloModulePass { + public: + StatusOr Run(HloModule* module) override; + + protected: + // Returns `true` if `instruction` should be expanded by this pass. + virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0; + + // Returns a replacement for `instruction`, or nullptr if no replacement is + // needed (e.g. only the to_apply subcomputation of the instruction was + // modified). + virtual StatusOr ExpandInstruction( + HloInstruction* instruction) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc new file mode 100644 index 0000000000000000000000000000000000000000..701c629add52a217f16877a085b9ef2d096623d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc @@ -0,0 +1,106 @@ +/* Copyright 2019 The TensorFlow Authors.
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// Returns true if the given shape is a non-nested tuple. +bool IsNonNestedTuple(const Shape& shape) { + return shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape); +} + +} // namespace + +StatusOr OptimizeInputOutputBufferAlias::Build( + const Shape& input_shape, const Shape& output_shape, + HloInputOutputAliasConfig* alias_config) { + bool changed = false; + TF_RET_CHECK(LayoutUtil::HasLayout(input_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(output_shape)); + VLOG(1) << "input_shape:" << input_shape.ToString(); + VLOG(1) << "output_shape:" << output_shape.ToString(); + + // For all buffers defined by the parameter, build a map from the byte + // size to the list of the buffers of that size. 
+ absl::flat_hash_map> size_to_input_index; + ShapeUtil::ForEachSubshape( + input_shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return; + } + int64 bytes = size_func_(subshape); + size_to_input_index[bytes].push(index); + }); + + // For each result buffer shape index, take the first unused parameter + // buffer that matches the size. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + output_shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return Status::OK(); + } + int64 bytes = size_func_(subshape); + + auto it = size_to_input_index.find(bytes); + if (it != size_to_input_index.end() && !it->second.empty()) { + changed = true; + const ShapeIndex& input_index = it->second.front(); + const ShapeIndex& output_index = index; + if (!alias_config->ParameterHasAlias(0, input_index) && + !alias_config->OutputHasAlias(output_index)) { + TF_RETURN_IF_ERROR(alias_config->SetUpAlias( + output_index, 0, input_index, + HloInputOutputAliasConfig::AliasKind::kSystemAlias)); + } + VLOG(3) << "Set up alias from with param index " + << it->second.front().ToString() << ", shape size " << bytes + << " and result subshape " + << ShapeUtil::HumanStringWithLayout(subshape) << " at index " + << index.ToString(); + it->second.pop(); + } + return Status::OK(); + })); + return changed; +} + +StatusOr OptimizeInputOutputBufferAlias::Run(HloModule* module) { + // User buffer alias only work for modules with 1 parameter. 
+ if (module->entry_computation()->num_parameters() != 1) { + return false; + } + + HloInputOutputAliasConfig* alias_config = + &module->input_output_alias_config(); + + return Build(module->entry_computation()->parameter_instruction(0)->shape(), + module->entry_computation()->root_instruction()->shape(), + alias_config); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h new file mode 100644 index 0000000000000000000000000000000000000000..79ce468e975300ed703ae0fd780f4b9d5328a4b3 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// This pass opportunistically finds input and output buffers that can be +// aliased, and writes the alias config into the HloModule. +// +// The input and the output buffers can be in any shape, and each output buffer +// can alias with an input buffer with the same size. Each input buffer may only +// alias with a single output buffer. For example, for the following parameter +// and the output buffers, +// +// Parameters : { P1(2MiB), P2(4MiB), P3(8MiB), P4(4MiB), P5(4MiB), ... } +// Outputs : { O1(4MiB), O2(2MiB), O3(4MiB), O4(6MiB), O5(4MiB), ... } +// +// one potential aliasing would be (O1, P2), (O2, P1), (O3, P4), (O5, P5), .. 
+class OptimizeInputOutputBufferAlias : public HloModulePass { + using ShapeSizeFunction = std::function; + + public: + OptimizeInputOutputBufferAlias(ShapeSizeFunction size_func) + : size_func_(size_func) {} + ~OptimizeInputOutputBufferAlias() override = default; + + absl::string_view name() const override { + return "optimize_input_output_buffer_alias.h"; + } + + StatusOr Run(HloModule* module) override; + + private: + friend class OptimizeInputOutputBufferAliasTest; + + StatusOr Build(const Shape& input_shape, const Shape& output_shape, + HloInputOutputAliasConfig* alias_config); + ShapeSizeFunction size_func_ = nullptr; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..41e90f9b6931619fd9824e2eda25e12e4c7197b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h" + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +// Tests that OptimizeInputOutputBufferAlias properly maps input and output +// buffer indices of various shapes for aliasing. +class OptimizeInputOutputBufferAliasTest : public HloTestBase { + protected: + OptimizeInputOutputBufferAliasTest() { + r1f32_ = ShapeUtil::MakeShape(F32, {4}); + r2f32_ = ShapeUtil::MakeShape(F32, {4, 5}); + r3f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6}); + r4f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + + auto size_func = [](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape); + }; + + optimize_pass_ = + absl::make_unique(size_func); + } + + // Returns the number of output indices that alias with the input. + int64 AliasCount() { + int64 count = 0; + + config_.ForEachAlias( + [&](const ShapeIndex&, const HloInputOutputAliasConfig::Alias&) { + count++; + }); + return count; + } + + bool BuildAliasConfig(const Shape& input_shape, const Shape& output_shape) { + config_ = HloInputOutputAliasConfig(output_shape); + auto changed = optimize_pass_->Build(input_shape, output_shape, &config_); + TF_CHECK_OK(changed.status()); + + return changed.ValueOrDie(); + } + + std::unique_ptr optimize_pass_; + + HloInputOutputAliasConfig config_; + + Shape r1f32_; + Shape r2f32_; + Shape r3f32_; + Shape r4f32_; +}; + +// All shapes are different, so no aliasing is available.
+TEST_F(OptimizeInputOutputBufferAliasTest, AllDifferentBufferSizes) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_}); + Shape output = ShapeUtil::MakeTupleShape({r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_FALSE(changed); + EXPECT_EQ(AliasCount(), 0); +} + +// Input and output shapes are equal, so buffers can alias at the same index. +TEST_F(OptimizeInputOutputBufferAliasTest, OrderedNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + EXPECT_EQ(AliasCount(), 4); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex{1}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{2}); + EXPECT_EQ(config_.GetAliasedOutput(0, {3}), ShapeIndex{3}); +} + +// Only a subset of the tuple element shapes match between the input and the +// output. +TEST_F(OptimizeInputOutputBufferAliasTest, PartialReuseNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r1f32_, r2f32_, r2f32_}); + Shape output = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 2); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{1}); +} + +// The output shape is reverse of the input shape, but we can still reuse all +// the buffers. 
+TEST_F(OptimizeInputOutputBufferAliasTest, UnorderedNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape({r4f32_, r3f32_, r2f32_, r1f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 4); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{3}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex{2}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{1}); + EXPECT_EQ(config_.GetAliasedOutput(0, {3}), ShapeIndex{0}); +} + +TEST_F(OptimizeInputOutputBufferAliasTest, UnorderedNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({r1f32_}), r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape( + {r1f32_, ShapeUtil::MakeTupleShape({r3f32_, r2f32_}), r2f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 3); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0, 0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex({1, 1})); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex({1, 0})); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index c35f72699bfe90f7b8021916c0f81d5e1926ff4c..7164bfc4cd48ea945519dadece92d8df2e88d02a 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -775,7 +775,7 @@ class ShapePatternIsArrayImpl { explicit constexpr ShapePatternIsArrayImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (!ShapeUtil::IsArray(*shape)) { + if (!shape->IsArray()) { EXPLAIN << "Shape is not an array"; return false; } @@ -793,7 +793,7 @@ class ShapePatternIsTupleImpl { explicit constexpr ShapePatternIsTupleImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if 
(!ShapeUtil::IsTuple(*shape)) { + if (!shape->IsTuple()) { EXPLAIN << "Shape is not a tuple"; return false; } @@ -831,7 +831,7 @@ class ShapePatternRankImpl { explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (ShapeUtil::Rank(*shape) != rank_) { + if (shape->rank() != rank_) { if (rank_ == 0) { EXPLAIN << "Shape is not a scalar"; } else { @@ -1737,7 +1737,8 @@ class HloConstantScalarImpl { literal_r0_as_val_ty_or.ValueOrDie() == val_literal && literal_r0 == val_as_literal_ty; if (!rv) { - EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString() + EXPLAIN << "HloInstruction's constant value " + << literal_r0.ToStringWithoutShape() << " did not match expected value " << *val_; } return rv; @@ -1877,7 +1878,7 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const -> decltype(this->WithShape(Shape().EqualTo(shape))) { return WithShape(Shape().EqualTo(shape)); } @@ -1885,7 +1886,7 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. 
template - constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { return WithShape(Shape().CompatibleTo(shape)); } @@ -2035,7 +2036,7 @@ XLA_UNOP_PATTERN(Ceil) XLA_UNOP_PATTERN(Convert) XLA_UNOP_PATTERN(Copy) XLA_UNOP_PATTERN(Cos) -XLA_UNOP_PATTERN(CrossReplicaSum) +XLA_UNOP_PATTERN(AllReduce) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) @@ -2052,11 +2053,12 @@ XLA_UNOP_PATTERN(RecvDone) XLA_UNOP_PATTERN(ReducePrecision) XLA_UNOP_PATTERN(Reshape) XLA_UNOP_PATTERN(Reverse) +XLA_UNOP_PATTERN(Rsqrt) XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) XLA_UNOP_PATTERN(Slice) -XLA_UNOP_PATTERN(Sort) +XLA_UNOP_PATTERN(Sqrt) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) #undef XLA_UNOP_PATTERN @@ -2118,7 +2120,6 @@ XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(DynamicSlice) XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) @@ -2235,8 +2236,10 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, XLA_VARIADIC_OP_PATTERN(AfterAll); XLA_VARIADIC_OP_PATTERN(Concatenate); XLA_VARIADIC_OP_PATTERN(CustomCall); +XLA_VARIADIC_OP_PATTERN(DynamicSlice) XLA_VARIADIC_OP_PATTERN(Map) XLA_VARIADIC_OP_PATTERN(Reduce); +XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); // Helpers for matching non-constant instructions. 
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc index 9ca2fb05c1f7ef093c58237cf21fbc7c813a592a..f51a18b13894d75300c46835fabd82a4ce0699af 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc @@ -23,7 +23,6 @@ namespace xla { namespace { namespace m = ::xla::match; -using ::testing::Eq; using ::testing::Not; template diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 186ef0c7911a2724df810780e018f52586e3e6a8..5c3c009a68bffbda8642fceedfb724879fbf1530 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -242,8 +242,8 @@ TEST(PatternMatcherTest, ConstantScalar) { HloModule test_module ENTRY test { a = s32[] constant(1) - b = s32[1,1] constant(s32[1,1]{{2}}) - c = s32[1,2] constant(s32[1,2]{{2,2}}) + b = s32[1,1] constant({{2}}) + c = s32[1,2] constant({{2,2}}) d = f32[] constant(1) e = f32[] constant(1.25) ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index c227106511c2c17b44569d3b696cd7d764226e81..886a0545624927fa77528141f61d8ecb6bec180a 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -70,6 +70,9 @@ PlatformUtil::GetSupportedPlatforms() { for (se::Platform* platform : all_platforms) { auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } platforms.push_back(platform); } else { LOG(INFO) << "platform " << platform->Name() << " present but no " @@ -205,7 +208,9 @@ static bool 
IsDeviceSupported(se::StreamExecutor* executor) { } /* static */ StatusOr> -PlatformUtil::GetStreamExecutors(se::Platform* platform) { +PlatformUtil::GetStreamExecutors( + se::Platform* platform, + const absl::optional>& allowed_devices) { int device_count = platform->VisibleDeviceCount(); if (device_count <= 0) { return NotFound("no %s devices found", platform->Name()); @@ -226,6 +231,17 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { tensorflow::thread::ThreadPool thread_pool( tensorflow::Env::Default(), "device_initialization", device_count); for (int i = 0; i < device_count; ++i) { + // Once a stream executor is instantiated it will cause allocations on + // the device, for example for GPUs cuda context, cudnn handles etc. will + // be constructed. By constructing stream executors only on the + // allowed_devices, we don't make any allocations on other devices. + // This helps in multi-process executions on the same host like horovod or + // shared hosts. + if (allowed_devices && allowed_devices->count(i) == 0) { + VLOG(1) << "Not initializing StreamExecutor for device " << i + << " since it is not in the visible device list"; + continue; + } thread_pool.Schedule([platform, i, &stream_executors]() { VLOG(1) << "Started device init " << i; se::StreamExecutorConfig config; @@ -247,8 +263,8 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { // Block here in thread_pool destructor until all devices are initialized. 
} VLOG(1) << "Device initialization complete"; - if (std::all_of(stream_executors.begin(), stream_executors.end(), - [](se::StreamExecutor* s) { return s == nullptr; })) { + if (absl::c_all_of(stream_executors, + [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", platform->Name()); } diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index 571451ba43a81d19b70e4954e45d3447f15dcedc..592b20282f334e12e0d7a7f683c9a6ab59d21fea 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ +#include #include #include @@ -60,10 +61,14 @@ class PlatformUtil { // Returns a vector of StreamExecutors for the given platform. The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an // element is nullptr, then the device is present by not supported by XLA. + // If populated, only the devices in allowed_devices will have + // their StreamExecutors initialized, otherwise all StreamExecutors will be + // initialized and returned. // // If the platform has no visible devices, a not-found error is returned. 
static StatusOr> GetStreamExecutors( - se::Platform* platform); + se::Platform* platform, + const absl::optional>& allowed_devices = absl::nullopt); private: TF_DISALLOW_COPY_AND_ASSIGN(PlatformUtil); diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 4df746fca9f8320eed72911726f33bb01f06fed5..a62118df157edf67114ff41befbdce3da129fe93 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -226,7 +226,10 @@ StatusOr PerformSinkReshapeOrTranspose( // changes, so all the fused instructions have the same dimensions. for (const auto& fused_instruction : instruction->fused_instructions()) { Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_operand_shape.dimensions(); + shape->clear_dimensions(); + for (int64 i : new_operand_shape.dimensions()) { + shape->add_dimensions(i); + } *shape->mutable_layout() = new_operand_shape.layout(); } } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 11c2f8392d285095816dd5d61f7029c1bfd158d4..acad871c4d427b174ffce3a462a0a3918a1e0c33 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -26,7 +26,6 @@ limitations under the License. namespace xla { - // Transposes the given scatter_indices such that the index_vector_dim becomes // the most-minor dimension. 
static StatusOr TransposeIndexVectorDimToLast( @@ -60,6 +59,13 @@ static StatusOr CanonicalizeScatterIndices( TF_ASSIGN_OR_RETURN( HloInstruction * transposed_scatter_indices, TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + if (scatter_indices->shape().rank() == index_vector_dim + 1 && + scatter_indices->shape().dimensions(index_vector_dim) == 1) { + auto new_shape = + ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); + TF_ASSIGN_OR_RETURN(scatter_indices, + MakeReshapeHlo(new_shape, scatter_indices)); + } bool indices_are_scalar = index_vector_dim == scatter_indices->shape().dimensions_size(); @@ -88,7 +94,7 @@ static StatusOr CanonicalizeScatterIndices( static StatusOr PermuteScatterAndWindowDims( HloInstruction* updates, absl::Span update_window_dims) { std::vector permutation; - const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + const int64 updates_rank = updates->shape().rank(); permutation.reserve(updates_rank); for (int64 i = 0; i < updates_rank; ++i) { @@ -165,10 +171,9 @@ static StatusOr CheckIndexValidity( // Valid range for the index: [0, operand_dims - window_sizes] // Check if the index has any negative values. - TF_ASSIGN_OR_RETURN( - HloInstruction * zero_index, + HloInstruction* zero_index = BroadcastZeros(computation, index->shape().element_type(), - AsInt64Slice(index->shape().dimensions()))); + AsInt64Slice(index->shape().dimensions())); TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); @@ -214,15 +219,11 @@ static StatusOr> ScatterLoopBody( HloInstruction* updates = loop_state[2]; bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; - CHECK_EQ(has_scalar_indices, - dim_numbers.index_vector_dim() == - scatter->operand(1)->shape().dimensions_size()); // Build a vector form of the induction variable of the while loop. 
- TF_ASSIGN_OR_RETURN( - HloInstruction * induction_var_as_vector, + HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + /*result_shape_bounds=*/{1}); // Pick the index to scatter from scatter_indices based on the induction_var // and transform that to an index into the `operand` space. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 5ec7fe2adedac2fc3d8a7588e853dba90e99006f..9bda6fba3aabfed78ae724545387e86bad36c886 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" @@ -113,6 +114,16 @@ int ServiceOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +ServiceOptions& ServiceOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + return *this; +} + +const absl::optional>& ServiceOptions::allowed_devices() const { + return allowed_devices_; +} + /* static */ StatusOr> Service::NewService( se::Platform* platform) { ServiceOptions default_options; @@ -129,6 +140,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { } BackendOptions backend_options; backend_options.set_platform(platform); + backend_options.set_allowed_devices(options.allowed_devices()); TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); std::unique_ptr service( @@ -150,17 +162,13 @@ 
Service::Service(const ServiceOptions& options, LOG(INFO) << StrFormat( "XLA service %p executing computations on platform %s. Devices:", this, execute_backend_->platform()->Name()); + auto stream_executors = execute_backend_->stream_executors(); for (int i = 0; i < execute_backend_->device_count(); ++i) { - if (execute_backend_->device_ordinal_supported(i)) { - se::StreamExecutor* executor = - execute_backend_->stream_executor(i).ValueOrDie(); - const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, - description.name(), - description.platform_version()); - } else { - LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i); - } + se::StreamExecutor* executor = stream_executors.at(i); + const auto& description = executor->GetDeviceDescription(); + LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, + description.name(), + description.platform_version()); } } else { VLOG(1) << "XLA compile-only service constructed"; @@ -288,11 +296,16 @@ StatusOr> Service::CreateModuleConfig( computation_layout->mutable_result_layout()->SetToDefaultLayout(); } - config->set_replica_count(options_.number_of_replicas()); if (execution_options != nullptr) { + if (execution_options->num_replicas() > 0) { + config->set_replica_count(execution_options->num_replicas()); + } else { + config->set_replica_count(options_.number_of_replicas()); + } config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { + config->set_replica_count(options_.number_of_replicas()); config->set_debug_options(GetDebugOptionsFromFlags()); } @@ -355,6 +368,7 @@ StatusOr>> Service::BuildExecutables( const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); 
module_group->push_back(std::move(module)); } @@ -516,13 +530,13 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const absl::Span> arguments, - Backend* backend, const string& result_tag, ExecutionProfile* profile) { + absl::Span> arguments, + Backend* backend, const DeviceHandle& device_handle, + const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector streams; - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*backend, SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle)); TF_RET_CHECK(!replicas.empty()); for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, @@ -530,10 +544,11 @@ StatusOr Service::ExecuteAndRegisterResult( streams.push_back(std::move(stream)); } - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), - /*computation_count=*/1)); + DeviceAssignment device_assignment(options_.number_of_replicas(), + /*computation_count=*/1); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, 0) = replicas[replica]->device_ordinal(); + } // Set up run options. std::vector run_options; @@ -545,9 +560,7 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - run_options.emplace_back( - options, backend->StreamBorrower(), - /*xla_intra_op_thread_pool=*/backend->eigen_intra_op_thread_pool()); + run_options.emplace_back(options, backend->StreamBorrower()); } if (options_.number_of_replicas() == 1) { @@ -704,14 +717,33 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, } } - // Execute the generated executables in parallel and return the device - // handles for each computation's output. 
+ // If we have multiple executables to run, execute them all in parallel. But + // if we only have one executable, execute it using the vanilla, non-parallel + // call. + // + // We do this because the Client API uses ExecuteGraphParallel when it wants + // to compile and run one computation without caching the executable, but not + // all backends support the async StreamExecutor API required by + // ExecuteParallelAndRegisterResult. + // + // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do + // basically the same thing. ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); + std::vector outputs; + if (executable_ptrs.size() == 1) { + TF_ASSIGN_OR_RETURN( + auto output, + ExecuteAndRegisterResult(executable_ptrs[0], all_arguments[0], + execute_backend_.get(), device_handles[0], + computation_names[0], &profile)); + outputs.push_back(std::move(output)); + } else { + TF_ASSIGN_OR_RETURN( + outputs, ExecuteParallelAndRegisterResult( + executable_ptrs, all_arguments, execute_backend_.get(), + device_handles, computation_names, &profile)); + } + for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; *response.mutable_output() = output; @@ -897,6 +929,7 @@ Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { *result->mutable_output(), ExecuteAndRegisterResult(executable.get(), replicated_arguments, execute_backend_.get(), + SingleComputationDeviceHandle(), "result of " + executable->module().name(), result->mutable_profile())); @@ -1078,9 +1111,11 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); + absl::optional output_layout; if (arg->has_output_layout()) { + output_layout = 
Layout::CreateFromProto(arg->output_layout()); TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), program_shape.result())); + *output_layout, program_shape.result())); } HloModuleConfig config(program_shape); @@ -1088,16 +1123,19 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module.get())); + HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( - *module, /*arg_literals=*/{})); + evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. - if (arg->has_output_layout()) { - result_literal = result_literal.Relayout(arg->output_layout()); + if (output_layout.has_value()) { + result_literal = result_literal.Relayout(*output_layout); } *result->mutable_literal() = result_literal.ToProto(); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 11e1a79552fbd944ab28da129b08cfe676fb08e9..fd907d07daef9e8337aeed198ef4fd23d069df21 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include @@ -52,7 +53,7 @@ class ServiceOptions { ServiceOptions& set_platform(se::Platform* platform); se::Platform* platform() const; - // Set the number of replicas to use when compiling replicated + // Set the default number of replicas to use when compiling replicated // programs. 
ServiceOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; @@ -61,10 +62,17 @@ class ServiceOptions { ServiceOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices set for selectively constructing stream executors + // on the platform. + ServiceOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_ = nullptr; int number_of_replicas_ = 1; int intra_op_parallelism_threads_ = -1; + absl::optional> allowed_devices_; }; // The XLA service object, which is the same across all platforms. It maintains @@ -242,8 +250,9 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const absl::Span> arguments, - Backend* backend, const string& result_tag, ExecutionProfile* profile); + absl::Span> arguments, + Backend* backend, const DeviceHandle& device_handle, + const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. 
The handles of the result diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index dbfed628bfcabffe66bef41a82e0e2430897d80d..6bee671056552b83014367889320b748659bbfdf 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -32,12 +32,10 @@ class ServiceExecutableRunOptions { ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} - explicit ServiceExecutableRunOptions( - ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) + explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options, + StreamBorrower borrow_stream = nullptr) : run_options_(std::move(run_options)), - borrow_stream_(std::move(borrow_stream)), - xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {} + borrow_stream_(std::move(borrow_stream)) {} // Returns reference or pointer to `ExecutableRunOptions` member. const ExecutableRunOptions& run_options() const { return run_options_; } @@ -56,15 +54,9 @@ class ServiceExecutableRunOptions { : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); } - // Returns reference to thread pool for execution of XLA ops on CPU backend. 
- tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const { - return xla_intra_op_thread_pool_; - } - private: ExecutableRunOptions run_options_; StreamBorrower borrow_stream_; - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 7e7282a737041458aed39b0054f901c23aa87d7a..431c2e3a5e0dac3093ba39640f3451bec6911f9f 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shape_inference.h" -#include #include +#include #include #include #include @@ -50,7 +50,7 @@ bool AllUnique(absl::Span slice) { } Status ExpectArray(const Shape& shape, absl::string_view op_type) { - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { return InvalidArgument("Expected array argument for %s, but got %s.", string(op_type), ShapeUtil::HumanString(shape)); } @@ -70,7 +70,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, const Shape& accumulator_shape = reducer_shape.result(); std::vector accumulator_subshapes; - if (ShapeUtil::IsArray(accumulator_shape)) { + if (accumulator_shape.IsArray()) { if (inputs != 1) { return InvalidArgument( "Reduction function must produce a tuple with %d elements, but " @@ -78,7 +78,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, inputs); } accumulator_subshapes.push_back(&accumulator_shape); - } else if (ShapeUtil::IsTuple(accumulator_shape)) { + } else if (accumulator_shape.IsTuple()) { if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { return InvalidArgument( "Reduction function must produce a tuple with %d elements, but has " @@ -96,7 +96,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, } for (const Shape* element_shape : accumulator_subshapes) { - if 
(ShapeUtil::Rank(*element_shape) != 0) { + if (element_shape->rank() != 0) { return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", @@ -156,17 +156,26 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, return Status::OK(); } +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} + StatusOr InferWindowOutputShape(const Shape& base_shape, const Window& window, PrimitiveType element_type, bool allow_negative_padding) { - if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { + if (window.dimensions_size() != base_shape.rank()) { return InvalidArgument( "Window has dimension %d but base shape has dimension %d.", - window.dimensions_size(), ShapeUtil::Rank(base_shape)); + window.dimensions_size(), base_shape.rank()); } std::vector output_dimensions(window.dimensions_size()); + std::vector output_is_dynamic(window.dimensions_size()); for (int64 i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { @@ -196,6 +205,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, window.DebugString()); } + if (base_shape.is_dynamic_dimension(i) && !IsTrivialWindowDimension(dim)) { + return Unimplemented( + "Dynamic shape is not supported for non trivial window: %s", + window_util::ToString(window)); + } + const int64 dilated_base = window_util::DilatedBound( ShapeUtil::GetDimension(base_shape, i), dim.base_dilation()); const int64 padded_dilated_base = @@ -205,9 +220,11 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, output_dimensions[i] = window_util::StridedBound( padded_dilated_base, dilated_window, dim.stride()); + output_is_dynamic[i] = 
base_shape.is_dynamic_dimension(i); } - return ShapeUtil::MakeValidatedShape(element_type, output_dimensions); + return ShapeUtil::MakeValidatedShape(element_type, output_dimensions, + output_is_dynamic); } } // namespace @@ -245,6 +262,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kExpm1: case HloOpcode::kLog: case HloOpcode::kLog1p: + case HloOpcode::kRsqrt: + case HloOpcode::kSqrt: case HloOpcode::kTanh: if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { @@ -338,7 +357,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); } - if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { + if (dimension < 0 || dimension >= arg_shapes[0]->rank()) { return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } @@ -351,12 +370,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, element_type = arg_shape->element_type(); continue; } - if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { + if (arg_shape->rank() != shape->rank()) { return InvalidArgument( "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), - ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); + arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(), + ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( @@ -364,8 +383,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, PrimitiveType_Name(arg_shape->element_type()), PrimitiveType_Name(shape->element_type())); } - for (int64 dimension_number = 0; - dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { + for (int64 dimension_number = 0; dimension_number < arg_shape->rank(); + ++dimension_number) { if 
(arg_shape->dimensions(dimension_number) != shape->dimensions(dimension_number)) { if (dimension_number == dimension) { @@ -401,7 +420,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape), PrimitiveType_Name(new_element_type)); } - if (!ShapeUtil::IsArray(operand_shape) || + if (!operand_shape.IsArray() || !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -424,7 +443,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape), PrimitiveType_Name(new_element_type)); } - if (!ShapeUtil::IsArray(operand_shape) || + if (!operand_shape.IsArray() || !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -472,7 +491,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { - if (!ShapeUtil::IsArray(operand_shape)) { + if (!operand_shape.IsArray()) { return InvalidArgument( "Pad operation does not support tuple-shape operands."); } @@ -480,7 +499,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Pad operation does not support non-scalar padding values."); } - if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { + if (operand_shape.rank() != padding_config.dimensions_size()) { return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", @@ -500,35 +519,44 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, padding_config.ShortDebugString()); } - std::vector dimensions(ShapeUtil::Rank(operand_shape)); + if 
(!padding_value_shape.is_static()) { + return InvalidArgument("Dynamic padding value is not supported"); + } + + std::vector dimensions(operand_shape.rank()); + std::vector is_dynamic(operand_shape.rank()); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); + if (operand_shape.is_dynamic_dimension(i) && p.edge_padding_high() != 0 && + p.edge_padding_low() != 0 && p.interior_padding() != 0) { + return InvalidArgument( + "Dynamic dimension on padding dimension is not supported."); + } dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + p.edge_padding_high() + std::max(operand_shape.dimensions(i) - 1, 0LL) * p.interior_padding(); + if (dimensions[i] < 0) { + return InvalidArgument("Padding result in negative size for dimension %d", + i); + } + is_dynamic[i] = operand_shape.is_dynamic_dimension(i); } + return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), - dimensions); + dimensions, is_dynamic); } // Current DotDimensionNumbers Requirements: // // Contracting Dimensions: -// *) Exactly one contracting dimension on both lhs and rhs. +// *) Same number of contracting dimensions on both lhs and rhs. // *) Contracting dimension size must be the same on both lhs and rhs. -// *) Contracting dimension numbers do not need to be the same (i.e. transposes -// are passed on to emitter implementations). // // Batch Dimensions: // *) Same number of batch dimensions on both lhs and rhs. -// *) Same batch dimension numbers (and sizes) on both lhs and rhs. -// *) Batch dimension numbers must be ordered before contracting and -// non-contracting/non-batch dimension numbers. -// -// Non-Contracting-Non-Batch Dimensions: -// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// *) Same batch dimension sizes on both lhs and rhs. 
// namespace { @@ -541,9 +569,8 @@ Status ValidateDotDimensionNumbers( absl::Span contracting_dims, absl::Span batch_dims) -> bool { auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; - return std::all_of(contracting_dims.begin(), contracting_dims.end(), - in_range) && - std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + return absl::c_all_of(contracting_dims, in_range) && + absl::c_all_of(batch_dims, in_range); }; absl::Span lhs_contracting_dimensions = @@ -555,9 +582,9 @@ Status ValidateDotDimensionNumbers( absl::Span rhs_batch_dimensions = AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); - if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions, lhs_batch_dimensions) || - !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + !dims_in_range(rhs.rank(), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", dimension_numbers.DebugString()); @@ -570,9 +597,8 @@ Status ValidateDotDimensionNumbers( auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; - return std::all_of(contracting_dims.begin(), contracting_dims.end(), - is_unique) && - std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + return absl::c_all_of(contracting_dims, is_unique) && + absl::c_all_of(batch_dims, is_unique); }; if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || @@ -581,36 +607,6 @@ Status ValidateDotDimensionNumbers( dimension_numbers.DebugString()); } - // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. 
- const int64 lhs_non_contracting_non_batch_dims = - ShapeUtil::Rank(lhs) - - dimension_numbers.lhs_contracting_dimensions_size() - - dimension_numbers.lhs_batch_dimensions_size(); - const int64 rhs_non_contracting_non_batch_dims = - ShapeUtil::Rank(rhs) - - dimension_numbers.rhs_contracting_dimensions_size() - - dimension_numbers.rhs_batch_dimensions_size(); - if (lhs_non_contracting_non_batch_dims < 0 || - lhs_non_contracting_non_batch_dims > 1 || - rhs_non_contracting_non_batch_dims < 0 || - rhs_non_contracting_non_batch_dims > 1) { - return InvalidArgument( - "Batch and contracting dimension number mismatch with rank."); - } - - // Check that batch dimension numbers are ordered before all others, and - // that they are monotonically increasing. - std::vector batch_dim_numbers(lhs_batch_dimensions.size()); - std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); - if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), - lhs_batch_dimensions.begin()) || - !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), - rhs_batch_dimensions.begin())) { - return InvalidArgument( - "Batch dimension numbers must precede non-batch dimensions and be" - "monotonically increasing."); - } - return Status::OK(); } @@ -637,28 +633,33 @@ Status ValidateDotDimensionNumbers( return fail("Element types do not match."); } - if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + if ((lhs.rank() < 1) || (rhs.rank() < 1)) { return fail("Dot only supports rank 1 or above."); } // Validate basic properties of dot dimension numbers. TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); - // Check that there is only one contracting dimension for both lhs and rhs. + // Check that number of contracting dimensions match. 
if (dimension_numbers.lhs_contracting_dimensions_size() != - dimension_numbers.rhs_contracting_dimensions_size() || - dimension_numbers.lhs_contracting_dimensions_size() != 1) { - return fail("Must specify one contracting dimension for both lhs and rhs."); + dimension_numbers.rhs_contracting_dimensions_size()) { + return fail( + "Must specify the same number of contracting dimensions for lhs and " + "rhs."); } - // Check that contracting dimension sizes match. - const int64 lhs_contracting_dimension = - dimension_numbers.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = - dimension_numbers.rhs_contracting_dimensions(0); - if (lhs.dimensions(lhs_contracting_dimension) != - rhs.dimensions(rhs_contracting_dimension)) { - return fail("Contracting dimension sizes do not match."); + for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size(); + ++i) { + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(i); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(i); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension) || + lhs.is_dynamic_dimension(lhs_contracting_dimension) != + rhs.is_dynamic_dimension(rhs_contracting_dimension)) { + return fail("Contracting dimension sizes do not match."); + } } // Check that number of batch dimensions match. @@ -669,11 +670,12 @@ Status ValidateDotDimensionNumbers( // Check that batch dimension numbers and sizes match. 
for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { - if (dimension_numbers.lhs_batch_dimensions(i) != - dimension_numbers.rhs_batch_dimensions(i) || - lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != - rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { - return fail("Batch dimension numbers and sizes must match for lhs/rhs."); + if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) || + lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.is_dynamic_dimension( + dimension_numbers.rhs_batch_dimensions(i))) { + return fail("Batch dimension sizes must match for lhs/rhs."); } } @@ -683,21 +685,29 @@ Status ValidateDotDimensionNumbers( // Generate the result dimensions in order, rhs dimensions followed by lhs // dimensions except the contracted and batch dimensions. std::vector dimensions; - std::unordered_set rhs_batch_dims( - dimension_numbers.rhs_batch_dimensions().begin(), - dimension_numbers.rhs_batch_dimensions().end()); - for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracting_dimension) { + std::vector is_dynamic; + for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) { + dimensions.push_back(lhs.dimensions(lhs_dim)); + is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim)); + } + for (int64 i = 0; i < lhs.rank(); i++) { + if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(), + i) && + !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) { dimensions.push_back(lhs.dimensions(i)); + is_dynamic.push_back(lhs.is_dynamic_dimension(i)); } } - for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { + for (int64 i = 0; i < rhs.rank(); i++) { + if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(), + i) && + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) 
{ dimensions.push_back(rhs.dimensions(i)); + is_dynamic.push_back(rhs.is_dynamic_dimension(i)); } } Shape result = ShapeUtil::MakeShape( - ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions); + ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -708,20 +718,24 @@ Status ValidateDotDimensionNumbers( ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& lhs, const Shape& rhs) { - TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); + TF_RET_CHECK(lhs.rank() == rhs.rank()); // The shapes have to be compatible. That is, if some dimension d has a // different size in the two shapes, one of them has to be 1 (a "degenerate" // dimension). In that case, the output shape has the non-1 dimension size // from the lhs/rhs pair in every index. - std::vector output_dimensions(ShapeUtil::Rank(lhs)); - for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) { + std::vector output_dimensions(lhs.rank()); + std::vector output_dimensions_is_dynamic(lhs.rank()); + for (int64 i = 0; i < lhs.rank(); ++i) { if (lhs.dimensions(i) == rhs.dimensions(i)) { output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else if (lhs.dimensions(i) == 1) { output_dimensions[i] = rhs.dimensions(i); + output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i); } else if (rhs.dimensions(i) == 1) { output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", @@ -730,7 +744,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - output_dimensions); + output_dimensions, output_dimensions_is_dynamic); } /* static */ 
StatusOr ShapeInference::InferInDimBroadcastShape( @@ -743,13 +757,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Automatic shape inference not supported: %s and %s", ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape)); - } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { + } else if (broadcast_dimensions.size() != smaller_shape.rank()) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " " lower-rank operand's rank is %d, size of broadcast_dimensions is " "%u.", - ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); + smaller_shape.rank(), broadcast_dimensions.size()); } // broadcast_dimensions is a sequence of dimensions; its length is equal to @@ -809,6 +823,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } int64 small_dimension_size = smaller_shape.dimensions(i); int64 large_dimension_size = larger_shape.dimensions(dimension_to_match); + bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i); + bool large_is_dynamic = + larger_shape.is_dynamic_dimension(dimension_to_match); // Dimension sizes must be compatible: match or be degenerate (degenerate // case is handled by degenerate dimension broadcasting which occurs after // InDim broadcasting). @@ -820,6 +837,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape)); } + if (small_is_dynamic != large_is_dynamic) { + if (small_dimension_size == large_dimension_size || + (small_dimension_size == 1 && !small_is_dynamic) || + (large_dimension_size == 1 && !large_is_dynamic)) { + // Do nothing. It's OK when the size-1 dimension is not static. 
+ } else { + return InvalidArgument( + "Broadcast dimension %d dynamism mismatch: %s and %s.", i, + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); + } + } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { @@ -829,6 +858,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } output_shape.set_dimensions(dimension_to_match, small_dimension_size); + output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic); } return output_shape; @@ -847,8 +877,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(rhs)); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { - std::vector identity_dims(ShapeUtil::Rank(lhs)); + if (lhs.rank() == rhs.rank()) { + std::vector identity_dims(lhs.rank()); std::iota(identity_dims.begin(), identity_dims.end(), 0); if (!broadcast_dimensions.empty() && broadcast_dimensions != identity_dims) { @@ -865,15 +895,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { + if (lhs.rank() == rhs.rank()) { return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs); } else { // Ranks do not match, so perform InDim broadcasting using // broadcast_dimensions. Scalar broadcasting is a special case of this. - const Shape& larger_shape = - ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs; - const Shape& smaller_shape = - ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; + const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs; + const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. 
TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, @@ -942,6 +970,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); + } else if (lhs.element_type() == F64 && rhs.element_type() == F64) { + return ShapeUtil::ChangeElementType(shape, C128); } else { return Unimplemented("Complex component type is not implemented."); } @@ -1162,12 +1192,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == Status::OK()); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } if (feature_index < 0) { @@ -1177,25 +1207,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, feature_index); } - if (ShapeUtil::Rank(operand_shape) < 1) { + if (operand_shape.rank() < 1) { return InvalidArgument( "Expected the rank of operand to " "batch-norm-training to be at least 1; got %d.", - ShapeUtil::Rank(operand_shape)); + operand_shape.rank()); } - if (ShapeUtil::Rank(offset_shape) != 1) { + if (offset_shape.rank() != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(offset_shape)); + offset_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1272,12 +1302,12 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == Status::OK()); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } if (feature_index < 0) { @@ -1287,25 +1317,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, feature_index); } - if (ShapeUtil::Rank(operand_shape) < 1) { + if (operand_shape.rank() < 1) { return InvalidArgument( "Expected the rank of operand to " "batch-norm-inference to be at least 1; got %d.", - ShapeUtil::Rank(operand_shape)); + operand_shape.rank()); } - if (ShapeUtil::Rank(offset_shape) != 1) { + if (offset_shape.rank() != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(offset_shape)); + offset_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1417,41 +1447,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape)); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } - if 
(ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { + if (operand_shape.rank() != output_grad_shape.rank()) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" " output_grad_shape; got rank(oprand_shape) %d, and" " rank(output_grad_shape) %d.", - ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); + operand_shape.rank(), output_grad_shape.rank()); } - if (ShapeUtil::Rank(mean_shape) != 1) { + if (mean_shape.rank() != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(mean_shape)); + mean_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } - if (ShapeUtil::Rank(var_shape) != 1) { + if (var_shape.rank() != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(var_shape)); + var_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1538,7 +1568,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Verify operand_shape and output_grad_shape have same bounds. 
- for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (ShapeUtil::GetDimension(operand_shape, i) != ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( @@ -1556,7 +1586,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dnums) { + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dnums) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1565,6 +1596,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "feature_group_count must be a positive number, got %d", feature_group_count); } + + if (batch_group_count <= 0) { + return InvalidArgument( + "batch_group_count must be a positive number, got %d", + batch_group_count); + } + + if (batch_group_count > 1 && feature_group_count > 1) { + return InvalidArgument( + "both batch_group_count %d and feature_group_count %d cannot be " + "greater than 1", + batch_group_count, feature_group_count); + } + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", @@ -1595,12 +1640,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } const int num_dims = num_spatial_dims + 2; - if (ShapeUtil::Rank(lhs) != num_dims) { + if (lhs.rank() != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs)); } - if (ShapeUtil::Rank(rhs) != num_dims) { + if (rhs.rank() != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; rhs: %s.", num_dims, ShapeUtil::HumanString(rhs)); @@ -1615,29 
+1660,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, input_dnums[1] = dnums.input_feature_dimension(); std::copy(dnums.input_spatial_dimensions().begin(), dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); - std::sort(input_dnums.begin(), input_dnums.end()); + absl::c_sort(input_dnums); std::vector window_dnums(num_dims); window_dnums[0] = dnums.kernel_input_feature_dimension(); window_dnums[1] = dnums.kernel_output_feature_dimension(); std::copy(dnums.kernel_spatial_dimensions().begin(), dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); - std::sort(window_dnums.begin(), window_dnums.end()); + absl::c_sort(window_dnums); std::vector output_dnums(num_dims); output_dnums[0] = dnums.output_batch_dimension(); output_dnums[1] = dnums.output_feature_dimension(); std::copy(dnums.output_spatial_dimensions().begin(), dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); - std::sort(output_dnums.begin(), output_dnums.end()); + absl::c_sort(output_dnums); std::vector expected_dnums(num_dims); std::iota(expected_dnums.begin(), expected_dnums.end(), 0); const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; - if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || - !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || - !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { + if (!absl::c_all_of(input_dnums, in_range) || + !absl::c_all_of(window_dnums, in_range) || + !absl::c_all_of(output_dnums, in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", dnums.DebugString()); @@ -1678,6 +1723,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); + if (batch_group_count > 1 && input_batch % kernel_output_features != 0) { + return InvalidArgument( + "Expected output feature 
dimension (value %d) to be divisible by " + "input_batch (value %d) for batch group count %d; " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + kernel_output_features, input_batch, batch_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); + } + if (input_features % feature_group_count != 0 || input_features / feature_group_count != kernel_input_features) { return InvalidArgument( @@ -1700,6 +1756,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } + + if (input_batch % batch_group_count > 0) { + return InvalidArgument( + "Expected input batch dimension (value %d) to be divisible by " + "batch_group_count (value %d); " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + input_batch, batch_group_count, ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), dnums.DebugString()); + } + std::vector window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { window_dims[i] = window.dimensions(i).size(); @@ -1722,14 +1789,39 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /*allow_negative_padding=*/true)); std::vector dimensions(num_dims); - dimensions[dnums.output_batch_dimension()] = input_batch; + dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count; dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); } + std::vector is_dynamic(num_dims); + for (int i = 0; i < num_dims; i++) { + if (lhs.is_dynamic_dimension(i)) { + if (i == dnums.input_batch_dimension()) { + is_dynamic[dnums.output_batch_dimension()] = true; + } else if (i == dnums.input_feature_dimension()) { + // Input feature dimension is a contracting dimension, which does not + // affect the output dimension size. 
So we need to do nothing. + } else { + return InvalidArgument( + "Dynamic Spatial Convolution is not supported: lhs shape is %s ", + lhs.ToString()); + } + } + if (rhs.is_dynamic_dimension(i)) { + if (i == dnums.kernel_input_feature_dimension()) { + // Kernel feature dimension does not affect the output dimension size. + // So we need to do nothing. + } else { + return InvalidArgument( + "Dynamic Spatial Convolution is not supported: rhs shape is %s ", + rhs.ToString()); + } + } + } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - dimensions); + dimensions, is_dynamic); } /* static */ StatusOr ShapeInference::InferFftShape( @@ -1750,7 +1842,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case FFT: case IFFT: if (in.element_type() != C64) { - return InvalidArgument("%s requires C64 input type, found %s.", + return InvalidArgument("%s requires complex input type, found %s.", FftType_Name(fft_type), PrimitiveType_Name(in.element_type())); } @@ -1773,6 +1865,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]); } } + if (ShapeUtil::IsZeroElementArray(in)) { + return in; + } Shape result = ShapeUtil::ChangeElementType(in, C64); result.set_dimensions(result.dimensions_size() - 1, fft_length[fft_rank - 1] / 2 + 1); @@ -1814,7 +1909,50 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, #undef RET_CHECK_RANK } -/* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( +/* static */ StatusOr ShapeInference::InferTriangularSolveShape( + const Shape& a, const Shape& b, const TriangularSolveOptions& options) { + if (a.rank() < 2) { + return InvalidArgument( + "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s", + a.ToString()); + } + if (b.rank() != a.rank()) { + return InvalidArgument( + "Arguments to triangular solve must have equal rank; got %s and %s.", + b.ToString(), a.ToString()); + } + if (a.dimensions(a.rank() 
- 2) != a.dimensions(a.rank() - 1)) { + return InvalidArgument( + "The two minor dimensions of 'a' must have equal size, got %s.", + a.ToString()); + } + if (a.dimensions(a.rank() - 1) != + b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) { + return InvalidArgument( + "The shared dimension of 'a' and 'b' does not match, got shapes %s and " + "%s", + a.ToString(), b.ToString()); + } + absl::Span a_batch_dims(a.dimensions()); + absl::Span b_batch_dims(b.dimensions()); + a_batch_dims.remove_suffix(2); + b_batch_dims.remove_suffix(2); + if (a_batch_dims != b_batch_dims) { + return InvalidArgument( + "The leading batch dimensions of the arguments to triangular solve " + "must be equal; got %s and %s.", + b.ToString(), a.ToString()); + } + if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) || + options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) { + return InvalidArgument( + "Invalid transpose option value for triangular solve (%d).\n", + options.transpose_a()); + } + return b; +} + +/* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( @@ -1834,12 +1972,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& shape, int64 split_dimension, int64 concat_dimension, int64 split_count) { TF_RET_CHECK(split_count > 0); - if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + if (split_dimension >= shape.rank() || split_dimension < 0) { return InvalidArgument( "AllToAll split_dimension %d is out-of-bounds in shape %s.", split_dimension, ShapeUtil::HumanString(shape)); } - if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + if (concat_dimension >= shape.rank() || concat_dimension < 0) { return InvalidArgument( "AllToAll concat_dimension %d is out-of-bounds in shape %s.", concat_dimension, ShapeUtil::HumanString(shape)); @@ -1877,7 +2015,7 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferCollectivePermuteShape( const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsArray(shape)); + TF_RET_CHECK(shape.IsArray()); return shape; } @@ -1901,7 +2039,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 i = 1; i < num_reduced_args; ++i) { if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( - "All reduced tensors must have the sime dimension. Tensor 0 has " + "All reduced tensors must have the same dimension. Tensor 0 has " "shape %s, Tensor %d has shape %s", ShapeUtil::HumanString(*reduced_args[0]), i, ShapeUtil::HumanString(*reduced_args[i])); @@ -1913,7 +2051,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // doesn't matter which one we choose. const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { - if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { + if (dimension >= arg.rank() || dimension < 0) { return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", dimension, ShapeUtil::HumanString(arg)); } @@ -1930,20 +2068,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); std::vector new_dimensions; - for (int i = 0; i < ShapeUtil::Rank(arg); ++i) { + std::vector new_is_dynamic; + for (int i = 0; i < arg.rank(); ++i) { if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { new_dimensions.push_back(arg.dimensions(i)); + new_is_dynamic.push_back(arg.is_dynamic_dimension(i)); } } if (ShapeUtil::IsScalar(to_apply.result())) { return ShapeUtil::MakeShape(to_apply.result().element_type(), - new_dimensions); + new_dimensions, new_is_dynamic); } else { std::vector result_subshapes; for (const Shape& subshape : to_apply.result().tuple_shapes()) { - 
result_subshapes.push_back( - ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + result_subshapes.push_back(ShapeUtil::MakeShape( + subshape.element_type(), new_dimensions, new_is_dynamic)); } return ShapeUtil::MakeTupleShape(result_subshapes); } @@ -2017,12 +2157,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(source_shape), ShapeUtil::HumanString(window_result_shape)); } + return operand_shape; } /* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( const Shape& shape, int64 dimension) { - if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) { + if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", dimension); } @@ -2064,10 +2205,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, starts.size(), strides.size())); } - if (starts.size() != ShapeUtil::Rank(arg)) { + if (starts.size() != arg.rank()) { return InvalidArgument( "Slice index count does not match argument rank: %u vs %d.", - starts.size(), ShapeUtil::Rank(arg)); + starts.size(), arg.rank()); } std::vector sizes; @@ -2102,41 +2243,87 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferDynamicSliceShape( - const Shape& operand_shape, const Shape& start_indices_shape, - absl::Span slice_sizes) { + const Shape& operand_shape, absl::Span start_index_shapes, + absl::Span slice_sizes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); - TF_RETURN_IF_ERROR( - ExpectArray(start_indices_shape, "start indices of dynamic slice")); + auto number_of_indices = start_index_shapes.size(); + // TODO(b/118437727): Remove this path. 
+ if (!allow_scalar_indices || + (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) { + if (number_of_indices != 1) { + return InvalidArgument( + "Dynamic slice should have exactly 1 index operand, has %d.", + number_of_indices); + } - VLOG(2) << StrFormat( - "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - ShapeUtil::HumanString(operand_shape), - ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); + const Shape& start_indices_shape = start_index_shapes[0]; + VLOG(2) << StrFormat( + "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + StrJoin(slice_sizes, ", ")); - if (ShapeUtil::Rank(start_indices_shape) != 1) { - return InvalidArgument( - "Dynamic slice start indices of rank %d must be rank1.", - ShapeUtil::Rank(start_indices_shape)); - } + TF_RETURN_IF_ERROR( + ExpectArray(start_indices_shape, "start indices of dynamic slice")); - if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { - return InvalidArgument( - "Dynamic slice start indices must be of integral type."); - } + if (start_indices_shape.rank() != 1) { + return InvalidArgument( + "Dynamic slice start indices of rank %d must be rank1.", + start_indices_shape.rank()); + } - const int64 start_num_dims = start_indices_shape.dimensions(0); - if (ShapeUtil::Rank(operand_shape) != start_num_dims) { - return InvalidArgument( - "Dynamic slice start number of dimensions %d (%s) must match rank " - "%d of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "Dynamic slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (operand_shape.rank() != start_num_dims) { + return InvalidArgument( + 
"Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + } + } else { + VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape), + StrJoin(slice_sizes, ", ")); + + if (operand_shape.rank() != number_of_indices) { + return InvalidArgument( + "Dynamic slice start number of dimensions %d must match rank " + "%d of slice input (%s).", + number_of_indices, operand_shape.rank(), + ShapeUtil::HumanString(operand_shape)); + } + + if (number_of_indices > 0) { + const Shape& first_index_shape = start_index_shapes[0]; + if (!ShapeUtil::IsScalar(first_index_shape)) { + return InvalidArgument("Dynamic slice indices must be scalar, not %s.", + ShapeUtil::HumanString(first_index_shape)); + } + if (!ShapeUtil::ElementIsIntegral(first_index_shape)) { + return InvalidArgument( + "Dynamic slice start indices must be of integral type."); + } + for (const Shape& index_shape : start_index_shapes) { + if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { + return InvalidArgument( + "Dynamic slice start indices must all have the same shape, got " + "mismatching indices with shapes %s and %s.", + ShapeUtil::HumanString(first_index_shape), + ShapeUtil::HumanString(index_shape)); + } + } + } } - if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { + if (slice_sizes.size() != operand_shape.rank()) { return InvalidArgument( "Dynamic slice index count does not match argument rank: %u vs %d.", - slice_sizes.size(), ShapeUtil::Rank(operand_shape)); + slice_sizes.size(), operand_shape.rank()); } for (int64 dim = 0; dim < slice_sizes.size(); ++dim) { @@ -2159,46 +2346,92 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, 
- const Shape& start_indices_shape) { + absl::Span start_index_shapes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR( ExpectArray(operand_shape, "operand of dynamic update slice")); TF_RETURN_IF_ERROR( ExpectArray(update_shape, "update of dynamic update slice")); - TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, - "start indices of dynamic update slice")); - VLOG(2) << StrFormat( - "updating slice of shape %s at dynamic start_indices %s with update " - "shape %s", - ShapeUtil::HumanString(operand_shape), - ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::HumanString(update_shape)); + auto number_of_indices = start_index_shapes.size(); + // TODO(b/118437727): Remove this path. + if (!allow_scalar_indices || + (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) { + if (number_of_indices != 1) { + return InvalidArgument( + "Dynamic update slice should have exactly 1 index operand, has %d.", + number_of_indices); + } + const Shape& start_indices_shape = start_index_shapes[0]; + TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); - if (ShapeUtil::Rank(start_indices_shape) != 1) { - return InvalidArgument( - "Dynamic update slice start indices of rank %d must be rank1.", - ShapeUtil::Rank(start_indices_shape)); - } + VLOG(2) << StrFormat( + "updating slice of shape %s at dynamic start_indices %s with update " + "shape %s", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::HumanString(update_shape)); - if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { - return InvalidArgument( - "Dynamic update slice start indices must be of integral type."); - } + if (start_indices_shape.rank() != 1) { + return InvalidArgument( + "Dynamic update slice start indices of rank %d must be rank1.", + start_indices_shape.rank()); + } - const int64 start_num_dims = start_indices_shape.dimensions(0); - if (ShapeUtil::Rank(operand_shape) != start_num_dims) { - return 
InvalidArgument( - "Dynamic update slice start number of dimensions %d (%s) must match " - "rank %d of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (operand_shape.rank() != start_num_dims) { + return InvalidArgument( + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + } + } else { + VLOG(2) << StrFormat("updating slice of shape %s with update shape %s", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(update_shape)); + + if (operand_shape.rank() != number_of_indices) { + return InvalidArgument( + "Dynamic update slice start number of dimensions %d must match " + "rank %d of slice input (%s).", + number_of_indices, operand_shape.rank(), + ShapeUtil::HumanString(operand_shape)); + } + + if (number_of_indices > 0) { + const Shape& first_index_shape = start_index_shapes[0]; + if (!ShapeUtil::IsScalar(first_index_shape)) { + return InvalidArgument( + "Dynamic update slice indices must be scalar, not %s.", + ShapeUtil::HumanString(first_index_shape)); + } + if (!ShapeUtil::ElementIsIntegral(first_index_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must be of integral type."); + } + for (const Shape& index_shape : start_index_shapes) { + if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must all have the same " + "shape, got mismatching indices with shapes %s and %s.", + ShapeUtil::HumanString(first_index_shape), + 
ShapeUtil::HumanString(index_shape)); + } + } + } } - if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { + if (update_shape.rank() != operand_shape.rank()) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " "%d vs %d.", - ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); + update_shape.rank(), operand_shape.rank()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, @@ -2210,7 +2443,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, PrimitiveType_Name(update_shape.element_type())); } - for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { + for (int64 dim = 0; dim < operand_shape.rank(); ++dim) { const int64 input_dim_size = operand_shape.dimensions(dim); const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { @@ -2236,7 +2469,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("a dimension number is duplicated in reverse"); } for (int64 dimension : dimensions) { - if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { + if (dimension >= operand_shape.rank() || dimension < 0) { return InvalidArgument( "One of the reverse dimensions (%d) is out-of-bounds in shape %s.", dimension, ShapeUtil::HumanString(operand_shape)); @@ -2247,13 +2480,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferGetTupleElementShape( const Shape& arg, int64 index) { - if (!ShapeUtil::IsTuple(arg)) { + if (!arg.IsTuple()) { return InvalidArgument( "Cannot infer shape: attempting to index into non-tuple: %s.", ShapeUtil::HumanString(arg)); } - if (index >= arg.tuple_shapes_size()) { + if (index < 0 || index >= arg.tuple_shapes_size()) { return InvalidArgument( "Cannot infer shape: attempt to index out of tuple bounds: %d " ">= %d in shape %s.", @@ -2283,7 +2516,7 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, }; // Check the shapes of computation parameters and return types. - if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { + if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Condition must return a boolean; got %s.", shape_string()); } @@ -2303,7 +2536,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& predicate, const Shape& true_operand, const Shape& false_operand, const ProgramShape& true_computation, const ProgramShape& false_computation) { - if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { + if (!ShapeUtil::Equal(predicate, ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Predicate must be a boolean; got %s.", ShapeUtil::HumanString(predicate)); } @@ -2378,8 +2611,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast")); TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast")); - const int64 operand_rank = ShapeUtil::Rank(operand_shape); - const int64 output_rank = ShapeUtil::Rank(output_shape); + const int64 operand_rank = operand_shape.rank(); + const int64 output_rank = output_shape.rank(); if (operand_rank > output_rank) { return InvalidArgument( "InDim style broadcast must be to an equal or higher ranked shape; " @@ -2402,11 +2635,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, operand_shape.dimensions(i) != 1) { return InvalidArgument( "Input dimension should be either 1 or equal to the output dimension " - "it's broadcasting into; the %lldth operand dimension is %lld, the " + "it is broadcasting into; the %lldth operand dimension is %lld, the " "%lldth output dimension is %lld.", i, operand_shape.dimensions(i), broadcast_dimensions[i], output_shape.dimensions(broadcast_dimensions[i])); } + if 
(operand_shape.is_dynamic_dimension(i) != + output_shape.is_dynamic_dimension(broadcast_dimensions[i])) { + return InvalidArgument( + "Broadcast input and output dynamism mismatch: %s and %s", + operand_shape.ToString(), output_shape.ToString()); + } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) { @@ -2438,9 +2677,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(inferred_shape)); } - std::vector indices(ShapeUtil::Rank(operand)); + std::vector indices(operand.rank()); std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != ShapeUtil::Rank(operand) || + if (dimensions.size() != operand.rank() || !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( @@ -2449,6 +2688,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand, inferred_shape); + for (auto& unmodified : unmodified_dims) { + if (operand.is_dynamic_dimension(unmodified.first)) { + inferred_shape.set_dynamic_dimension(unmodified.second, true); + } + } + return inferred_shape; } @@ -2456,11 +2703,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); - std::vector indices(ShapeUtil::Rank(operand)); - std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != ShapeUtil::Rank(operand) || - !std::is_permutation(dimensions.begin(), dimensions.end(), - indices.begin())) { + if (!IsPermutation(dimensions, operand.rank())) { return InvalidArgument( "Transpose dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", @@ -2522,19 +2765,31 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred)); } - if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || + if (Shape::Equal() + .IgnoreElementType() + .IgnoreLayout() + .IgnoreDynamicDimension()(pred, on_true) || ShapeUtil::IsScalar(pred)) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. - return ShapeUtil::ChangeElementType( + Shape inferred_shape = ShapeUtil::ChangeElementType( on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); - } else { - return InvalidArgument( - "Select operation with non-scalar predicate with dimensionality " - " different from the other operands: %s.", - ShapeUtil::HumanString(pred)); + + // Propagate dynamic dimensions if pred is not a scalar. + if (!ShapeUtil::IsScalar(pred)) { + for (int i = 0; i < inferred_shape.rank(); i++) { + if (pred.is_dynamic_dimension(i)) { + inferred_shape.set_dynamic_dimension(i, true); + } + } + } + return inferred_shape; } + return InvalidArgument( + "Select operation with non-scalar predicate with dimensionality " + "different from the other operands: %s.", + ShapeUtil::HumanString(pred)); } /* static */ StatusOr ShapeInference::InferTupleSelectShape( @@ -2810,7 +3065,7 @@ Status ValidateScatterDimensionNumbers( "update_window_dims in scatter op must not repeat; got: %s.", StrJoin(dim_numbers.update_window_dims(), ", ")); } - const int64 updates_rank = ShapeUtil::Rank(updates_shape); + const int64 updates_rank = updates_shape.rank(); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( @@ -2844,10 +3099,10 @@ Status ValidateScatterDimensionNumbers( // Validate window size. 
auto window_size = dim_numbers.update_window_dims_size() + dim_numbers.inserted_window_dims_size(); - if (window_size != ShapeUtil::Rank(operand_shape)) { + if (window_size != operand_shape.rank()) { return InvalidArgument( "Scatter op has window of size %d; doesn't match operand of rank %d.", - window_size, ShapeUtil::Rank(operand_shape)); + window_size, operand_shape.rank()); } // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. @@ -2932,10 +3187,9 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); - if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + if (updates_shape.rank() != expected_updates_rank) { return InvalidArgument("Updates tensor must be of rank %d; got %d.", - expected_updates_rank, - ShapeUtil::Rank(updates_shape)); + expected_updates_rank, updates_shape.rank()); } TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( @@ -2966,7 +3220,7 @@ Status ValidateScatterDimensionNumbers( } int64 scatter_dims_seen = 0; - for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + for (int64 i = 0; i < updates_shape.rank(); ++i) { bool is_update_window_dim = absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); if (is_update_window_dim) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d94385a04d50baff8156570a09620fd458547936..acb071ab18824472153fc608b812ad2d9c52651e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,16 +109,20 @@ class ShapeInference { // filter (rhs) to lhs in the way specified by the fields on window. 
static StatusOr InferConvolveShape( const Shape& lhs, const Shape& rhs, int64 feature_group_count, - const Window& window, + int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); // Infers the shape produced by the given FFT type on the given operand. static StatusOr InferFftShape(const Shape& in, FftType fft_type, absl::Span fft_length); + // Infers the shape produced by the given triangular solve operation. + static StatusOr InferTriangularSolveShape( + const Shape& a, const Shape& b, const TriangularSolveOptions& options); + // Infers the shape produced by a cross replica sum with the given operand // shapes. - static StatusOr InferCrossReplicaSumShape( + static StatusOr InferAllReduceShape( absl::Span operand_shapes); // Infers final shape of an Alltoall operation that is created by the xla @@ -176,14 +180,15 @@ class ShapeInference { // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr InferDynamicSliceShape( - const Shape& operand_shape, const Shape& start_indices_shape, - absl::Span slice_sizes); + const Shape& operand_shape, absl::Span start_index_shapes, + absl::Span slice_sizes, bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. static StatusOr InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, - const Shape& start_indices_shape); + absl::Span start_index_shapes, + bool allow_scalar_indices = true); // Infers the shape produced by doing a compile-time-constant indexing into // the given input shape. 
This is essential for operations on tuples, because diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 4639e32db4d59080a9e85e46983fac61d9e76be9..f400ef51f07b006eef2ea674feff1dd72f836e77 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. const Shape s32_ = ShapeUtil::MakeShape(S32, {}); + const Shape f16_ = ShapeUtil::MakeShape(F16, {}); const Shape f32_ = ShapeUtil::MakeShape(F32, {}); const Shape f64_ = ShapeUtil::MakeShape(F64, {}); const Shape pred_ = ShapeUtil::MakeShape(PRED, {}); @@ -251,7 +252,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, - const absl::Span& bcast) { + absl::Span bcast) { return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, bcast); }; @@ -260,8 +261,8 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok()); // Component types must match. ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok()); - // Only F32->C64 supported. - ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok()); + // Only F32->C64 and F64->C128 supported. + ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok()); // Validate correct uses. 
Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {})); @@ -285,6 +286,9 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {})); ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {}))); } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { @@ -420,7 +424,8 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_window_dilation(1); dim1->set_base_dilation(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -465,7 +470,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_window_dilation(2); dim1->set_base_dilation(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -510,7 +516,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_window_dilation(1); dim1->set_base_dilation(2); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = 
inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), @@ -548,7 +555,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_padding_low(1); dim1->set_padding_high(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); @@ -888,6 +896,20 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie())); } +TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) { + Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_}); + auto inferredNegative_status = + ShapeInference::InferGetTupleElementShape(tuple_shape, -1); + auto inferred2_status = + ShapeInference::InferGetTupleElementShape(tuple_shape, 2); + ASSERT_FALSE(inferredNegative_status.ok()); + ASSERT_FALSE(inferred2_status.ok()); + EXPECT_THAT(inferredNegative_status.status().error_message(), + HasSubstr("attempt to index out of tuple bounds")); + EXPECT_THAT(inferred2_status.status().error_message(), + HasSubstr("attempt to index out of tuple bounds")); +} + TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = ShapeInference::InferBinaryOpShape( @@ -1002,9 +1024,9 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch and contracting dimension number mismatch")); + EXPECT_TRUE(inferred_status.ok()); + 
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), + ShapeUtil::MakeShape(F32, {32, 32, 64}))); } // vector vector -> scalar @@ -1096,7 +1118,6 @@ TEST_F(ShapeInferenceTest, DotGeneral) { TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); - Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); @@ -1110,8 +1131,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Must specify one contracting dimension for both " - "lhs and rhs")); + HasSubstr("Must specify the same number of contracting " + "dimensions for lhs and rhs.")); +} + +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + EXPECT_TRUE(inferred_status.ok()); + EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape)); } // BatchMatMul with different batch dimension sizes fails. 
@@ -1130,11 +1171,11 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch dimension numbers and sizes must match")); + HasSubstr("Batch dimension sizes must match")); } -// BatchMatMul with different batch dimension numbers fails. -TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { +// BatchMatMul with different batch dimension numbers passes +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); @@ -1147,9 +1188,9 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch dimension numbers must precede non-batch")); + ASSERT_TRUE(inferred_status.ok()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), + ShapeUtil::MakeShape(F32, {2, 11, 14}))); } // BatchMatMul with out-of-range dimension numbers fails. 
@@ -1440,6 +1481,14 @@ TEST_F(ShapeInferenceTest, Pad) { Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape)); + + dimension1->set_edge_padding_low(-20); + dimension1->set_edge_padding_high(-10); + auto negative_dimension_size = ShapeInference::InferPadShape( + input_shape, padding_value_shape, padding_config); + ASSERT_FALSE(negative_dimension_size.ok()); + ASSERT_THAT(negative_dimension_size.status().error_message(), + HasSubstr("negative size for dimension 1")); } TEST_F(ShapeInferenceTest, Reverse) { @@ -1523,6 +1572,16 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } +TEST_F(ShapeInferenceTest, Rank1Transpose) { + Shape a_shape = ShapeUtil::MakeShape(F32, {5}); + auto inferred_shape_and_status = + ShapeInference::InferTransposeShape(a_shape, {0}); + EXPECT_IS_OK(inferred_shape_and_status); + Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); + EXPECT_TRUE( + ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5}))); +} + TEST_F(ShapeInferenceTest, Conditional) { auto inferred_status0 = ShapeInference::InferConditionalShape( pred_, vector_32_, vector_64_, diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 28a30b5ee2dbcb5012804578d4d037c241045309..d90dde3b13d3aa9e1de10dd9e1d11a8e6da170de 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -85,7 +85,7 @@ string ShapedBuffer::ToString() const { on_device_shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { string shape_str; - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { shape_str = "tuple"; } else { shape_str = ShapeUtil::HumanStringWithLayout(subshape); diff --git a/tensorflow/compiler/xla/service/sort_simplifier.cc b/tensorflow/compiler/xla/service/sort_simplifier.cc new file mode 100644 index 
0000000000000000000000000000000000000000..122366a0f322a66963b364e1b19629cbd2d9aabe --- /dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier.cc @@ -0,0 +1,165 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sort_simplifier.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace { + +// If the sort instruction has a tuple shape then looks for unused output +// values and removes them from the sort instruction. Returns true if the +// graph has been modified. +StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { + if (!sort->shape().IsTuple()) { + return false; + } + + HloComputation* computation = sort->parent(); + + if (computation->root_instruction() == sort) { + // Can't analyse users of the root instruction. + return false; + } + + absl::flat_hash_set used_indices; + for (const HloInstruction* user : sort->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + // Can't analyse users other then get-tuple-element. 
+ return false; + } + used_indices.insert(user->tuple_index()); + } + + // Also note which parameters are used by the comparator computation. + auto comparator = sort->to_apply(); + for (int64 i = 0; i < sort->operand_count() * 2; ++i) { + if (comparator->parameter_instruction(i)->user_count() > 0) { + // operand i corresponds to parameters 2 * i and 2 * i + 1 of the + // computation. + used_indices.insert(i / 2); + } + } + + if (used_indices.size() == sort->operand_count()) { + // All operands are used. + return false; + } + + std::vector operands; + std::vector new_shapes; + for (int64 i = 0; i < sort->operand_count(); ++i) { + if (used_indices.contains(i)) { + operands.push_back(sort->mutable_operand(i)); + new_shapes.push_back(sort->operand(i)->shape()); + } + } + + Shape new_sort_shape = new_shapes.size() == 1 + ? new_shapes[0] + : ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, operands)); + absl::flat_hash_map> + replacements; + int64 parameter_number = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto* old_lhs_parameter = comparator->parameter_instruction(i * 2); + auto* old_rhs_parameter = comparator->parameter_instruction(i * 2 + 1); + if (used_indices.contains(i)) { + Shape scalar_shape = + ShapeUtil::MakeShape(sort->operand(i)->shape().element_type(), {}); + replacements[old_lhs_parameter] = HloInstruction::CreateParameter( + parameter_number, scalar_shape, + absl::StrCat("p.", parameter_number / 2, ".lhs")); + ++parameter_number; + replacements[old_rhs_parameter] = HloInstruction::CreateParameter( + parameter_number, scalar_shape, + absl::StrCat("p.", parameter_number / 2, ".rhs")); + ++parameter_number; + } else { + replacements[old_lhs_parameter] = nullptr; + replacements[old_rhs_parameter] = nullptr; + } + } + HloModule* module = sort->GetModule(); + HloComputation* new_compare = module->AddEmbeddedComputation( + 
comparator->CloneWithReplacements(std::move(replacements))); + new_sort->set_to_apply(new_compare); + + // Map from original get-tuple-element tuple index to new HLO instruction + absl::flat_hash_map result_map; + if (new_sort->shape().IsTuple()) { + // Old sort key maps to new sort key. + int64 new_index = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + if (used_indices.count(i)) { + result_map[i] = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_shapes[new_index], new_sort, new_index)); + ++new_index; + } + } + } else { + CHECK_EQ(used_indices.size(), 1); + result_map[*used_indices.begin()] = new_sort; + } + std::vector users(sort->users().begin(), + sort->users().end()); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR( + user->ReplaceAllUsesWith(result_map.at(user->tuple_index()))); + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(user)); + } + return true; +} +} // namespace + +StatusOr SortSimplifier::Run(HloModule* module) { + VLOG(2) << "HLO module before SortSimplifier:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + std::vector sort_instrs; + for (auto* comp : module->MakeNonfusionComputations()) { + absl::c_copy_if(comp->instructions(), std::back_inserter(sort_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kSort; + }); + } + + for (HloInstruction* sort_instr : sort_instrs) { + TF_ASSIGN_OR_RETURN(bool result, RemoveUnusedOperandFromSort(sort_instr)); + changed |= result; + } + + if (changed) { + VLOG(2) << "HLO module after SortSimplifier:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after SortSimplifier"; + } + + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/sort_simplifier.h b/tensorflow/compiler/xla/service/sort_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..8c6f313aa04f51e14a14450bc72fc622d74133a4 --- 
/dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes unused operands from sort. An operand at index 'x' is +// unused if output 'x' is not used and the comparator does not read it. +class SortSimplifier : public HloModulePass { + public: + absl::string_view name() const override { return "simplify-sorts"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/sort_simplifier_test.cc b/tensorflow/compiler/xla/service/sort_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..696ac1b465848894f8dcb1c88bc48c6a5b268ef4 --- /dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sort_simplifier.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using SortSimplifierTest = HloTestBase; + +TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + uint64 num_executions = 0; + do { + num_executions++; + } while 
(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(num_executions, 2); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(0)))); +} + +TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.2.lhs = u32[] parameter(4) + p.2.rhs = u32[] parameter(5) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,87] parameter(0) + values.0 = s32[64,87] parameter(1) + values.1 = u32[64,87] parameter(2) + sort = (f32[64,87], s32[64,87], u32[64,87]) sort( + keys, values.0, values.1), + dimensions={1}, to_apply=compare + gte.0 = f32[64,87] get-tuple-element(sort), index=0 + gte.1 = u32[64,87] get-tuple-element(sort), index=2 + ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + GmockMatch(m::Tuple( + m::GetTupleElement(m::Sort(m::Parameter(0), m::Parameter(2)), 0), + m::GetTupleElement(m::Sort(m::Parameter(0), m::Parameter(2)), 1)))); +} + +TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1}, to_apply=compare + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), 
index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.1.lhs, p.1.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + uint64 num_executions = 0; + do { + num_executions++; + } while (simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(num_executions, 2); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(1)))); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.cc b/tensorflow/compiler/xla/service/stable_sort_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..1aa7e5fe7c0d57ee3303480e4727c456727f64c8 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.cc @@ -0,0 +1,204 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Looks for an iota operand that can be used as tie breaker in the +// computation. If no matching iota operand is found, an iota operand is added +// to Sort. The comparison computation is adjusted to break ties using the +// values from the iota operand. +StatusOr StableSortExpander::ExpandInstruction( + HloInstruction* instruction) { + auto* sort = Cast(instruction); + HloComputation* computation = sort->parent(); + + HloInstruction* expanded_sort = nullptr; + absl::flat_hash_set used_indices; + int64 iota_index = -1; + for (const HloInstruction* operand : sort->operands()) { + // We can only use the iota operand if it has an iota dimension which is the + // same as the dimension to sort. Also it should have an integral type that + // is large enough for the number of elements in the sort dimension. For + // now, we only allow S32, because we expect to find an S32 iota operand for + // all Sort ops which are created by TopK. 
+ // TODO(b/122298745): Also support other types. + if (operand->opcode() == HloOpcode::kIota && + Cast(operand)->iota_dimension() == + sort->sort_dimension() && + operand->shape().element_type() == S32) { + iota_index = sort->operand_index(operand); + break; + } + } + + // If there is currently no iota operand which we could use for making the + // sort stable, we will have to add a new such operand. + if (iota_index == -1) { + Shape iota_shape = sort->operand(0)->shape(); + // We might need to use S64 if the number of elements in the sort dimension + // is bigger than 2^31 - 1. + // TODO(b/122298745): Handle Sort ops where S32 is too small for the number + // of elements in the sort dimension. + if (iota_shape.dimensions(sort->sort_dimension()) > + std::numeric_limits::max()) { + return Unimplemented( + "Stable sorting of more than 2^31-1 elements is not implemented"); + } + iota_shape.set_element_type(S32); + auto iota = computation->AddInstruction( + HloInstruction::CreateIota(iota_shape, sort->sort_dimension())); + + // Create a new comparator. + auto comparator = sort->to_apply(); + absl::flat_hash_map> + replacements; + std::vector> extra_parameters; + std::vector extra_parameter_ptrs; + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".lhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2 + 1, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".rhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation( + comparator->CloneWithReplacements(std::move(replacements), + extra_parameter_ptrs))); + + // Replace the original sort op. 
+ std::vector new_operands(sort->operands().begin(), + sort->operands().end()); + new_operands.push_back(iota); + std::vector new_shapes = sort->operand_count() == 1 + ? std::vector{sort->shape()} + : sort->shape().tuple_shapes(); + new_shapes.push_back(iota_shape); + Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, new_operands)); + + // Add a "wrapper" around the new sort op to make sure we have the same + // shape as before. For a single operand, we only need a GetTupleElement, + // otherwise we create a Tuple consisting of GetTupleElements of the new + // sort. + std::vector tuple_elements; + tuple_elements.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + sort->operand(i)->shape(), new_sort, i))); + } + expanded_sort = tuple_elements[0]; + if (tuple_elements.size() > 1) { + expanded_sort = computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + } + sort = Cast(new_sort); + iota_index = sort->operand_count() - 1; + } + + // Modify the computation to break ties using the iota operand. + auto comparator = sort->to_apply(); + std::vector instructions_postorder = + comparator->MakeInstructionPostOrder(); + absl::flat_hash_map replacements; + // Look up instr in the replacements map, and return either the replacement, + // or instr, if the replacement isn't present. 
+ auto replace = [&](HloInstruction* instr) { + auto it = replacements.find(instr); + if (it == replacements.end()) { + return instr; + } + return it->second; + }; + HloInstruction* old_root = comparator->root_instruction(); + // The comparison computation gets 2 * n parameters (n being the number of + // operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two + // different scalars of operand i of Sort which are to be compared. The + // comparison computation should induce a strict weak order, so if + // to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to + // to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the + // values to be compared are equivalent, and perform a tie-breaker comparison. + // + // We clone each instruction with at least one operand, but use as new + // operands of the instruction the replacements of the original operands. + // Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This + // should make sure that the cloned root instruction gives the result of the + // comparison computation when being called with each scalar pair reversed. + // This includes the parameters corresponding to the iota operand. 
+ for (int64 i = 0; i < comparator->num_parameters(); ++i) { + replacements[comparator->parameter_instruction(i)] = + comparator->parameter_instruction(i ^ 1); + } + HloInstruction* cloned_root = nullptr; + for (HloInstruction* inst : instructions_postorder) { + if (inst->operand_count() == 0) { + continue; + } + std::vector new_operands; + new_operands.reserve(inst->operand_count()); + for (HloInstruction* operand : inst->operands()) { + new_operands.push_back(replace(operand)); + } + auto new_instruction = + inst->CloneWithNewOperands(inst->shape(), new_operands); + replacements[inst] = new_instruction.get(); + if (inst == old_root) { + cloned_root = new_instruction.get(); + } + comparator->AddInstruction(std::move(new_instruction)); + } + CHECK_NE(cloned_root, nullptr); + Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); + HloInstruction* same = + comparator->AddInstruction(HloInstruction::CreateBinary( + scalar_pred, HloOpcode::kEq, old_root, cloned_root)); + HloInstruction* tie_breaker = + comparator->AddInstruction(HloInstruction::CreateBinary( + scalar_pred, HloOpcode::kLt, + comparator->parameter_instruction(2 * iota_index), + comparator->parameter_instruction(2 * iota_index + 1))); + HloInstruction* new_root = + comparator->AddInstruction(HloInstruction::CreateTernary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker, + old_root)); + comparator->set_root_instruction(new_root); + + return expanded_sort; +} + +bool StableSortExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSort && + Cast(instruction)->is_stable(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.h b/tensorflow/compiler/xla/service/stable_sort_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..31b6fd92d25370218017c58072f1aa5e64df00c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.h @@ -0,0 +1,42 @@ +/* 
Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which expands Sort ops that have the is_stable field set to true +// into equivalent Sort ops which guarantee stable sorting without relying on +// the is_stable field. 
+class StableSortExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "stable-sort-expander"; } + + private: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a62d953e6e8fa2f3c1ecfd9e4a7900eee74f9dca --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc @@ -0,0 +1,358 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using StableSortExpanderTest = HloTestBase; + +// Checks whether 'a' and 'b' are roots of equivalent computations, except that +// parameters 2 * i and 2 * i + 1 are switched. +bool IsSameComputationExceptParams(const HloInstruction* a, + const HloInstruction* b) { + if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) { + return false; + } + if (a->opcode() == HloOpcode::kParameter) { + // Check that parameters were switched. + return a->parameter_number() == (b->parameter_number() ^ 1); + } + // If the operation has no operands, it should actually be the same. + if (a->operand_count() == 0) { + return a == b; + } + // Otherwise recursively compare all operands. + for (int64 i = 0; i < a->operand_count(); ++i) { + if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) { + return false; + } + } + return true; +} + +// Check that the comparison computation has been modified to add a tie breaker +// using 'iota_parameter'. 
+void CheckComputationHasTieBreaker(const HloInstruction* root, + int64 iota_parameter) { + // With the tie breaker, the root instruction should be + // Select(Eq(Comp(), CompReverse()), Lt(), Comp()) + // with Comp() being the original comparison function, and CompReverse() being + // the copied comparison function where the parameters are reversed. Lt() is + // the tie breaker comparison using the Iota operand. + ASSERT_EQ(root->opcode(), HloOpcode::kSelect); + ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq); + + // Check that the tie breaker instruction is correct. + EXPECT_THAT(root->operand(1), + GmockMatch(m::Lt(m::Parameter(iota_parameter * 2), + m::Parameter(iota_parameter * 2 + 1)))); + EXPECT_EQ(root->operand(2), root->operand(0)->operand(0)); + + // Check that Comp() and CompReverse() are equivalent except that + // CompReverse() has reversed parameters. + EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0), + root->operand(0)->operand(1))); +} + +TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + 
root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortReuseIotaOperandComplicatedComparison) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + max = u32[] constant(2147483647) + zero = s32[] constant(0) + lhs.signed = s32[] bitcast-convert(p.0.lhs) + lhs.unsigned = u32[] bitcast-convert(p.0.lhs) + lhs.flipped = u32[] subtract(max, lhs.unsigned) + lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped) + lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero) + lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed) + rhs.signed = s32[] bitcast-convert(p.0.rhs) + rhs.unsigned = u32[] bitcast-convert(p.0.rhs) + rhs.flipped = u32[] subtract(max, rhs.unsigned) + rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped) + rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero) + rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed) + ROOT lt = pred[] less-than(lhs.converted, rhs.converted) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, 
StabilizeSortAddIotaOperandAndChangeRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, GmockMatch(m::Tuple( + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0), + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, HonorIsStableFlag) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=false + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_FALSE(stabilizer.Run(module.get()).ValueOrDie()); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortDontReuseIotaOperandWrongDimension) { + 
const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=0 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + 
StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] less-than(lhs, rhs) + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] less-than(lhs, rhs) + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + sort = s32[64,8732]{1,0} sort(keys), 
dimensions={0}, to_apply=compare, + is_stable=true + ROOT neg = s32[64,8732]{1,0} negate(sort) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Negate(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/1); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index a21e586efadb85d18e88e44999283b28f7f65eac..15ef623cc7b2dbc31e9cba5c4783c39b8805a5aa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -142,7 +142,7 @@ Status TransferManager::TransferArrayToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); - TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) + TF_RET_CHECK(on_device_shape.IsArray()) << "On-device representation of " << ShapeUtil::HumanString(literal.shape()) << " is not an array: " << ShapeUtil::HumanString(on_device_shape); @@ -227,7 +227,7 @@ Status TransferManager::WriteTupleIndexTablesAsync( return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { - if (ShapeUtil::IsTuple(device_subshape)) { + if (device_subshape.IsTuple()) { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); @@ -248,6 +248,22 @@ Status TransferManager::WriteTupleIndexTablesAsync( }); } +Status TransferManager::WriteRootTupleIndexTable( + se::Stream* 
stream, const ShapedBuffer& device_buffer) { + TF_RET_CHECK(device_buffer.on_device_shape().IsTuple()); + se::DeviceMemoryBase device_memory = device_buffer.buffer({}); + TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) == + device_memory.size()); + + std::vector elements; + for (int64 i = 0; + i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) { + elements.push_back(device_buffer.buffer({i})); + } + return WriteSingleTupleIndexTable( + stream, elements, device_buffer.on_device_shape(), &device_memory); +} + Status TransferManager::TransferBufferFromDevice( se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, void* destination) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 49f0b8f8b72001f07200d3e94828f60fcb0fa8fb..43a50487c636da75224547286a31625db3f91330 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -146,6 +146,12 @@ class TransferManager { Status WriteTupleIndexTablesAsync(se::Stream* stream, const ShapedBuffer& device_buffer); + // Writes a tuple index buffer for the root of 'device_buffer', which must + // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer, + // rather than writing all subbuffers. This method is always asynchronous. + Status WriteRootTupleIndexTable(se::Stream* stream, + const ShapedBuffer& device_buffer); + // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. 
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 7c1f4b5cc67dd2a84271b4f2b8015fdb2ff6e846..a95ca2bf2a8fcd700eb9234cafbfce9b62f2370c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -45,7 +45,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); - } else if (ShapeUtil::Rank(operand.shape()) != 2) { + } else if (operand.shape().rank() != 2) { return {}; } } @@ -130,8 +130,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { HloInstruction* new_lhs; const int64 kLhsIdx = 0; - if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) != - operand_indices.end()) { + if (absl::c_linear_search(operand_indices, kLhsIdx)) { HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx); const auto& transpose_dimensions = transpose.dimensions(); HloInstruction& transpose_operand = *transpose.mutable_operand(0); @@ -154,8 +153,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { HloInstruction* new_rhs; const int64 kRhsIdx = 1; - if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) != - operand_indices.end()) { + if (absl::c_linear_search(operand_indices, kRhsIdx)) { HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx); const auto& transpose_dimensions = transpose.dimensions(); HloInstruction& transpose_operand = *transpose.mutable_operand(0); @@ -178,7 +176,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(), - convolution.window(), new_dnums, convolution.precision_config()); + convolution.batch_group_count(), convolution.window(), new_dnums, + convolution.precision_config()); 
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 17cdaa74fc328d156292f5af828d4222a9a01f1f..f8a5fa0215007310d6bec35d20fc643afc824dda 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -139,9 +139,9 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloModule FoldDotTransposeConstant ENTRY entry_computation { - constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + constant = f32[2,1]{1,0} constant({ { 1 }, { 2 } }) transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} - constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + constant.1 = f32[3,2]{1,0} constant({ { 1, 2 }, { 3, 4 }, { 5, 6 } }) transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -240,12 +240,13 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, - dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = @@ -295,12 +296,13 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { 
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, - dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = @@ -355,12 +357,13 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, - dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = @@ -421,12 +424,13 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, - dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = 
builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc similarity index 75% rename from tensorflow/compiler/xla/client/lib/triangular_solve.cc rename to tensorflow/compiler/xla/service/triangular_solve_expander.cc index c5a1d34cc66e6f8c1a832f8a8437163b846a5431..b26cdc1db59b30d82b9ac58a8a2ac762220086be 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include #include @@ -33,12 +33,14 @@ limitations under the License. 
namespace xla { +namespace { + // Get the diagonal blocks of the coefficient matrix XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); - int ndims = ShapeUtil::Rank(shape); + int ndims = shape.rank(); int64 n = ShapeUtil::GetDimension(shape, -1); int64 num_blocks = n / block_size; @@ -62,15 +64,26 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { /*broadcast_sizes=*/{2}), /*permutation=*/{1, 0}); + PaddingConfig padding_config = + MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}}); + start_indices = + Pad(start_indices, ConstantR0(builder, 0), padding_config); + // Gather the diagonal blocks + std::vector slice_sizes(ndims); GatherDimensionNumbers dim_numbers; + for (int i = 0; i < ndims - 2; ++i) { + dim_numbers.add_offset_dims(i); + dim_numbers.add_start_index_map(i); + slice_sizes[i] = ShapeUtil::GetDimension(shape, i); + } + slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size; dim_numbers.add_offset_dims(ndims - 1); dim_numbers.add_offset_dims(ndims); dim_numbers.add_start_index_map(ndims - 2); dim_numbers.add_start_index_map(ndims - 1); dim_numbers.set_index_vector_dim(1); - diag_blocks = Gather(a, start_indices, dim_numbers, - /*slice_sizes=*/{block_size, block_size}); + diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes); } // The last block might be smaller than the block size, @@ -129,9 +142,7 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, // zero (which can happen if the last block was padded) otherwise it will // introduce nans which will propagate auto diags = GetMatrixDiagonal(diag_blocks); - TF_ASSIGN_OR_RETURN(Shape diags_shape, builder->GetShape(diags)); - auto one = ScalarLike(diags, 1); - auto ones = Broadcast(one, AsInt64Slice(diags_shape.dimensions())); + auto ones = FullLike(diags, 1); diags = Select(Eq(diags, Zero(builder, shape.element_type())), 
ones, diags); auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); @@ -154,10 +165,10 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, // The first or last diagonal element should be set to 1 instead of -1 // though, since we never update it auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); - auto start_index = (lower) ? 0 : block_size - 1; - auto output_block = DynamicUpdateSlice( - neg_identity, pos_one, - /*start_indices=*/ConstantR1(builder, 2, start_index)); + auto start_index = ConstantR0(builder, (lower) ? 0 : block_size - 1); + auto output_block = + DynamicUpdateSlice(neg_identity, pos_one, + /*start_indices=*/{start_index, start_index}); // Broadcast diag([1, -1, -1, ...]) to every block XlaOp output = Broadcast(output_block, @@ -200,12 +211,10 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, auto body_out = GetTupleElement(input_tuple, 1); auto body_input = GetTupleElement(input_tuple, 2); - auto zero = ConstantR1(bodyb.get(), 1, 0); + auto zero = ConstantR0(bodyb.get(), 0); auto j = (lower) ? 
i : ScalarLike(i, block_size - 1) - i; - auto start_indices = - ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); auto input_row = - DynamicSlice(body_input, start_indices, + DynamicSlice(body_input, {zero, j, zero}, /*slice_sizes=*/{num_blocks, 1, block_size}); // We want -L21 L11^{-1} @@ -219,7 +228,7 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); - body_out = DynamicUpdateSlice(body_out, update, start_indices); + body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); auto next_i = i + ScalarLike(i, 1); Tuple(bodyb.get(), {next_i, body_out, body_input}); @@ -251,7 +260,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - int64 ndims = ShapeUtil::Rank(a_shape); + int64 ndims = a_shape.rank(); int64 n = ShapeUtil::GetDimension(a_shape, -1); int64 num_blocks = n / block_size + (n % block_size != 0); int64 m_dim = (left_side) ? 
-1 : -2; @@ -338,20 +347,21 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, }); } -XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool transpose_a, bool conjugate_a, int64 block_size, - PrecisionConfig::Precision precision) { +XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); - if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) { + if (a_shape.rank() != b_shape.rank()) { return InvalidArgument( "Arguments to TriangularSolve have shapes with different ranks: " "%s vs. %s", ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } - const int64 ndims = ShapeUtil::Rank(a_shape); + const int64 ndims = a_shape.rank(); if (ndims < 2) { return InvalidArgument( "Arguments to TriangularSolve was rank %d but must have rank >= 2.", @@ -393,6 +403,26 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, block_size); } + if (ShapeUtil::IsZeroElementArray(b_shape)) { + // The output has the same shape as 'b', and since the output has zero + // elements, any such array will do. + return b; + } + + // TODO(phawkins): consider pushing triangle masking into + // InvertDiagonalBlocks. + if (unit_diagonal) { + // Mask everything but the subdiagonal/superdiagonal elements. + a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a)) + : Select(TriangleMask(a, 0), ZerosLike(a), a); + int64 k = ShapeUtil::GetDimension(a_shape, -1); + a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k), + /*broadcast_dimensions=*/{ndims - 2, ndims - 1}); + } else { + // Mask off the ignored elements of the triangular matrix a. 
+ a = Triangle(a, lower); + } + // We find the diagonal blocks of the coefficient matrix auto diag_blocks = DiagonalBlocks(a, block_size); @@ -409,4 +439,66 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, }); } +} // namespace + +bool TriangularSolveExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kTriangularSolve; +} + +StatusOr TriangularSolveExpander::ExpandInstruction( + HloInstruction* instruction) { + const TriangularSolveOptions& options = + instruction->triangular_solve_options(); + const string name = absl::StrFormat( + "xla.triangular_solve_%s_%s_%s_%s_%s_%s", + instruction->operand(0)->shape().ToString(), + instruction->operand(1)->shape().ToString(), + options.left_side() ? "left" : "right", + options.lower() ? "lower" : "upper", + TriangularSolveOptions_Transpose_Name(options.transpose_a()), + options.unit_diagonal() ? "unit" : "nonunit"); + + HloModule* module = instruction->parent()->parent(); + + HloComputation*& computation = + computation_cache_.emplace(name, nullptr).first->second; + if (!computation) { + // Builds a new expansion. + // + // We do something unusual here: we build the computation using the + // XlaBuilder API, which is nominally an XLA client API. We do this because + // the external APIs for building complicated computations (XlaBuilder) + // are much more ergonomic than the internal ones. As it turns out, + // XlaBuilder isn't really a client API—what it does is build a + // HloModuleProto protocol buffer, that we can then deserialize and clone + // into our HloModule. Ideally we would avoid the protocol buffer step; + // that is left as an exercise for future work. 
+ XlaBuilder builder(name); + XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a"); + XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b"); + bool transpose_a = + options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE; + bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT; + + BuildTriangularSolve(a, b, options.left_side(), options.lower(), + transpose_a, conjugate_a, options.unit_diagonal(), + /*block_size=*/128, + /*precision=*/PrecisionConfig::HIGHEST); + TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( + xla_computation.proto(), config)); + HloCloneContext context(module); + computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + } + + return instruction->parent()->AddInstruction(HloInstruction::CreateCall( + instruction->shape(), instruction->operands(), computation)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..be2374ef8c86254d8db5ac1acac385aa0de7d3a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +class TriangularSolveExpander : public OpExpanderPass { + public: + absl::string_view name() const override { + return "triangular_solve_expander"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + private: + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 50d51eaeb762e208004c1dae3dcc27503f3f94e9..cc82e9bb0287b5a586fb21fee35d3124a6d6f121 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -55,11 +56,10 @@ bool PointsToSet::IsAmbiguous() const { bool PointsToSet::IsDistinct() const { bool distinct = true; - std::set all_points_to; - ForEachElement([&distinct, &all_points_to](const ShapeIndex& /*index*/, - const BufferList& points_to) { + absl::flat_hash_set all_points_to; + ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) { for (auto& buffer : points_to) { - if (all_points_to.count(buffer) != 0) { + if (all_points_to.contains(buffer)) { distinct = false; } all_points_to.insert(buffer); @@ -87,9 +87,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool found = false; ForEachElement([&found, &buffer](const ShapeIndex& /*index*/, const BufferList& pointed_to_buffers) { - if (!found && - std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), - &buffer) != pointed_to_buffers.end()) { + if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) { found = true; } }); @@ -99,8 +97,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer, const ShapeIndex& index) const { const auto& pointed_to_buffers = element(index); - return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), - &buffer) != pointed_to_buffers.end(); + return absl::c_linear_search(pointed_to_buffers, &buffer); } void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer, @@ -210,7 +207,7 @@ Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) { &logical_buffer_analysis_->GetBuffer(hlo_instruction, index)); }); - if (ShapeUtil::IsTuple(hlo_instruction->shape())) { + if (hlo_instruction->shape().IsTuple()) { // If the hlo instruction is a tuple-shaped, then trivially the instruction // itself is the source of the tuple. 
points_to_set.add_tuple_source({}, hlo_instruction); @@ -604,9 +601,8 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( } else if (user->opcode() == HloOpcode::kFusion && user->fusion_kind() == HloInstruction::FusionKind::kLoop) { // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { + auto it = absl::c_find_if( + user->fused_parameters(), [&](HloInstruction* fused_param) { return user->operand(fused_param->parameter_number()) == operand; }); CHECK(it != user->fused_parameters().end()); @@ -672,9 +668,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( } // Find fusion parameter associated with 'operand'. const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { + auto fused_param_it = + absl::c_find_if(fused_params, [&](HloInstruction* fused_param) { return fusion->operand(fused_param->parameter_number()) == operand; }); if (fused_param_it == fused_params.end()) { @@ -704,6 +699,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index // 0. // (5) The 'user' of 'operand' is Sort, and it is the only user. +// (6) The 'user' of 'operand' is TriangularSolve, it is the second operand, +// and it is the only user. // // (2) and (3) can only be determined if points-to analysis is available. bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( @@ -743,11 +740,10 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( // Check if one operand of kAdd fused root is kDot or kConvolution. 
auto* add = user->fused_expression_root(); auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); + absl::c_find_if(add->operands(), [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); if (add_operand_it == add->operands().end()) { return false; } @@ -785,6 +781,14 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; } + if (user->opcode() == HloOpcode::kTriangularSolve) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 1; + } if (user->opcode() == HloOpcode::kCall) { // TODO(b/62548313): Remove when buffer assignment is module scoped and // does not assign buffers to calls. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 561762b5d424ed5f537665be9d67a81dc8bdd56e..6f61fc44166298e86a88dfc4f0ce8526d65ffd02 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" @@ -623,7 +624,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { void Run(const bool add_additional_gte0_user) { Shape input_shape = ShapeUtil::MakeShape(F32, {8}); Shape update_shape = ShapeUtil::MakeShape(F32, {3}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape}); @@ -657,7 +658,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2)); // Update 'input' with 'update' at dynamic 'starts' indices. builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - input_shape, input, update, starts)); + input_shape, input, update, {starts})); // Build computation and add it to module as entry computation. BuildModule(builder.Build()); @@ -721,9 +722,8 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { // to fusion 'operand'. HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion, HloInstruction* operand) { - auto it = std::find_if( - fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), [=](const HloInstruction* fused) { + auto it = absl::c_find_if( + fusion->fused_instructions(), [&](const HloInstruction* fused) { return fused->opcode() == HloOpcode::kParameter && fusion->operand(fused->parameter_number()) == operand; }); @@ -734,7 +734,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { // Returns all users of 'fusion_paran' at 'tuple_index'. 
std::vector GetFusionParameterUsersAt( HloInstruction* fusion_param, int64 tuple_index) { - CHECK(ShapeUtil::IsTuple(fusion_param->shape())); + CHECK(fusion_param->shape().IsTuple()); std::vector users_at_tuple_index; for (auto user : fusion_param->users()) { CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode()); @@ -883,12 +883,12 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -977,12 +977,12 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -1004,7 +1004,7 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {}); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); auto update = builder.AddInstruction( @@ -1012,7 +1012,7 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { auto starts = builder.AddInstruction( HloInstruction::CreateParameter(2, starts_shape, "starts")); auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); + data_shape, data, update, {starts})); BuildModuleAndRunAnalysis(builder.Build()); @@ -1066,14 +1066,17 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto sort = - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, + &builder, module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + 
computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); EXPECT_TRUE( points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); @@ -1081,6 +1084,7 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); Shape values_shape = ShapeUtil::MakeShape(F32, {8}); @@ -1088,11 +1092,14 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { HloInstruction::CreateParameter(0, keys_shape, "keys")); auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); - auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, - {values})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); // The buffer for the keys can be shared with the first tuple entry. 
EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc index cfb0c787d09557fd1aec3517eb9698cfec323369..90ea79ec263a038556ccbd2cd345b337c5a5dcf3 100644 --- a/tensorflow/compiler/xla/service/tuple_util.cc +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -21,7 +21,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple, int64 elements) { - CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + CHECK(input_tuple->shape().IsTuple()); HloComputation* computation = input_tuple->parent(); const Shape& input_shape = input_tuple->shape(); @@ -41,7 +41,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::AppendSuffix( HloInstruction* input_tuple, absl::Span trailing_values) { - CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + CHECK(input_tuple->shape().IsTuple()); HloComputation* computation = input_tuple->parent(); const Shape& input_shape = input_tuple->shape(); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 68e2569f66bea9ec1223e454d1ead0efc7b9498e..c93a9ba3176002a34fe84a29e62075de4d19168f 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -301,7 +301,7 @@ optional ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) { /*dest_shape_index=*/{indvar_index}, /*src_shape_index=*/{})); StatusOr eval_result = - evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + evaluator.Evaluate(*while_cond, {std::move(fake_input)}); if (!eval_result.ok()) { VLOG(2) << "Couldn't evaluate while loop condition."; diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 75d406435b6f58faecc86b82c33e9e2dd6bccbea..3bcf5c38309a86e9e3cab3268f3f065005f7a923 100644 --- 
a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -129,7 +129,7 @@ condition { ENTRY entry { const_0 = f32[2] constant({1, 2}) - const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + const_1 = (f32[2], f32[2]) constant(({2, 1},{3,1})) while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body } @@ -206,8 +206,8 @@ body { p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 - token = token[] after-all() - outfeed = token[] outfeed(p_body.0, token) + token0 = token[] after-all() + outfeed = token[] outfeed(p_body.0, token0) ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) } @@ -305,7 +305,7 @@ condition { ENTRY entry { const_0 = f32[] constant(0) - const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + const_1 = (f32[], f32[]) constant((1, 10)) while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 41011176ffa91e885bc58364d1fb19617d3518ad..69cc8feb3f31ad782b9d3437d81d0ab8ce10aadb 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -89,7 +89,7 @@ static void CreateLoopInvariantCopy( HloInstruction* next_operand = frame->instruction->mutable_operand(frame->operand_index++); - if (hoisted_instructions->count(next_operand) || + if (hoisted_instructions->contains(next_operand) || next_operand == while_body_param) { continue; } @@ -127,7 +127,7 @@ 
WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); - if (!ShapeUtil::IsTuple(while_instr->shape())) { + if (!while_instr->shape().IsTuple()) { // This restriction leaves one interesting pattern on the table: // // while_body(f32[1024, 1024] %param) { @@ -168,7 +168,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( // is no benefit to hoisting them unless something that uses it is also // hoisted. for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { - if (ShapeUtil::IsArray(instr->shape())) { + if (instr->shape().IsArray()) { // TODO(b/79147885): We should try to generalize this to tuples for // uniformity's sake, if nothing else. InsertOrDie(&unhoisted_invariant_instructions, instr); @@ -221,7 +221,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( ShapeUtil::ForEachSubshape( operand->shape(), [&input_size](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { input_size += ShapeUtil::ByteSizeOfElements(subshape); } }); @@ -229,7 +229,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( ShapeUtil::ForEachSubshape( instruction->shape(), [&output_size](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { output_size += ShapeUtil::ByteSizeOfElements(subshape); } }); @@ -241,7 +241,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( auto is_invariant = [&](HloInstruction* op) { return hoisted_instructions.find(op) != hoisted_instructions.end() || - unhoisted_invariant_instructions.count(op) || + unhoisted_invariant_instructions.contains(op) || op->opcode() == HloOpcode::kConstant; }; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc 
b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 8e7c4bc8828552e197b41f874c070d496b85a382..3587c016b4420163a607422b1acc838646fab83a 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -299,7 +299,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1}); auto token_shape = ShapeUtil::MakeTokenShape(); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); @@ -314,10 +314,12 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); HloInstruction* in_token = builder.AddInstruction( HloInstruction::CreateGetTupleElement(token_shape, param, 2)); - HloInstruction* bitcast_inst = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); - HloInstruction* out_token = builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); + HloInstruction* bitcast_inst = + builder.AddInstruction(HloInstruction::CreateUnary( + effective_scalar_s32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = + builder.AddInstruction(HloInstruction::CreateOutfeed( + effective_scalar_s32, bitcast_inst, in_token, "")); builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); @@ -352,9 +354,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { // The bitcast's user can be hoisted, so hoist the bitcast too. 
auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = - ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32}); + auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1}); + Shape while_shape = ShapeUtil::MakeTupleShape( + {scalar_s32, effective_scalar_s32, effective_scalar_s32}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -363,12 +365,13 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { HloInstruction* gte_0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_f32, param, 1)); - HloInstruction* bitcast_inst = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction::CreateGetTupleElement(effective_scalar_s32, param, 1)); + HloInstruction* bitcast_inst = + builder.AddInstruction(HloInstruction::CreateUnary( + effective_scalar_s32, HloOpcode::kBitcast, gte_0)); HloInstruction* add_inst = builder.AddInstruction(HloInstruction::CreateBinary( - scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1)); + effective_scalar_s32, HloOpcode::kAdd, bitcast_inst, gte_1)); builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index d30f67dd8110b88166fe807762fb653190ec00bc..386ffb995477ff1b4aef73080b6a6fd988dd1980 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -58,7 +58,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloComputation* while_body = while_op->while_body(); HloInstruction* while_body_root = while_body->root_instruction(); - if 
(!ShapeUtil::IsTuple(while_init->shape())) { + if (!while_init->shape().IsTuple()) { VLOG(2) << "While op's carried value isn't tuple shaped."; return false; } @@ -109,8 +109,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // operand appears in, but it may appear more than once! if (user->user_count() == 1 && user->users().front() == while_body_root && while_body_root->operand_index(user) == user->tuple_index() && - std::count(while_body_root->operands().begin(), - while_body_root->operands().end(), user) == 1) { + absl::c_count(while_body_root->operands(), user) == 1) { continue; } @@ -127,7 +126,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // through to the while body's root, count that element as "used", since // removing that element would be observable. for (int64 i = 0; i < while_body_root->operand_count(); ++i) { - if (used_tuple_indices.count(i)) { + if (used_tuple_indices.contains(i)) { continue; } @@ -158,7 +157,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Build up maps from the old/new to the new/old tuple indices. std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), used_tuple_indices.end()); - std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); + absl::c_sort(new_to_old_tuple_idx); absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { @@ -181,7 +180,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // replace the old instructions after we remove unused elements from the while // tuple. 
auto make_while_computation_replacements = [&](const HloComputation* comp) { - std::unordered_map> + absl::flat_hash_map> replacements; auto* param = comp->parameter_instruction(0); @@ -233,7 +232,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_cond->CloneWithReplacements( make_while_computation_replacements(while_cond)); - std::unordered_map> + absl::flat_hash_map> while_body_replacements = make_while_computation_replacements(while_body); std::vector new_while_body_root_elems; new_while_body_root_elems.reserve(new_to_old_tuple_idx.size()); @@ -583,8 +582,7 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { static std::unique_ptr UnflattenTupleInstr( absl::Span instrs, const Shape& desired_shape, std::vector>* new_instrs) { - CHECK(ShapeUtil::IsTuple(desired_shape)) - << ShapeUtil::HumanString(desired_shape); + CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape); // For each child shape in `desired_shape`, slice out the correct number of // `instrs` and call UnflattenTupleInstr recursively. 
At each step we remove @@ -593,7 +591,7 @@ static std::unique_ptr UnflattenTupleInstr( std::vector elems; for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) { const Shape& subshape = desired_shape.tuple_shapes(i); - if (!ShapeUtil::IsTuple(subshape)) { + if (!subshape.IsTuple()) { elems.push_back(instrs[0]); instrs.remove_prefix(1); continue; @@ -603,7 +601,7 @@ static std::unique_ptr UnflattenTupleInstr( int64 num_leaves = 0; ShapeUtil::ForEachSubshape( subshape, [&](const Shape& s, const ShapeIndex& /*index*/) { - if (!ShapeUtil::IsTuple(s)) { + if (!s.IsTuple()) { ++num_leaves; } }); @@ -625,7 +623,7 @@ static std::vector GetFlatTupleElems( HloInstruction* instr, std::vector>* new_instrs) { const auto& shape = instr->shape(); - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { return {instr}; } std::vector elems; @@ -665,7 +663,7 @@ static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { std::vector flattened_shape_elems; ShapeUtil::ForEachSubshape(while_shape, [&](const Shape& s, const ShapeIndex& /*index*/) { - if (!ShapeUtil::IsTuple(s)) { + if (!s.IsTuple()) { flattened_shape_elems.push_back(s); } }); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 4950e8269e9cf0723d717bd1734518d104c0c9f2..ecca76b1e86d833c73fbb9bad6a341660a7d2669 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -406,13 +407,12 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { // The original while instruction is still left in the module as a dead // instruction, find a while instruction with a different name as the new // while instruction. + const auto& instrs = m->entry_computation()->instructions(); HloInstruction* new_while_op = - *std::find_if(m->entry_computation()->instructions().begin(), - m->entry_computation()->instructions().end(), - [&](const HloInstruction* instr) { - return (instr->opcode() == HloOpcode::kWhile && - instr->name() != "while"); - }); + *absl::c_find_if(instrs, [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); EXPECT_TRUE( @@ -554,8 +554,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { HloInstruction* new_while = FindFirstWhile(m.get()); Shape flat_tuple = - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") - .ValueOrDie(); + ParseShape("(s32[1], s32[2], s32[3], s32[4])").ValueOrDie(); SCOPED_TRACE(m->ToString()); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( @@ -567,8 +566,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") - .ValueOrDie())); + ParseShape("((s32[1]), (s32[2], s32[3], (s32[4])))").ValueOrDie())); } // Edge-case: All elements of the loop carry are constants which can be 
removed, @@ -641,8 +639,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); HloInstruction* new_while = FindFirstWhile(m.get()); - Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + Shape new_while_shape = ParseShape("(s32[1], s32[3])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); @@ -652,9 +649,9 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(ShapeUtil::Equal( new_while->while_condition()->parameter_instruction(0)->shape(), new_while_shape)); - EXPECT_TRUE(ShapeUtil::Equal( - m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_TRUE( + ShapeUtil::Equal(m->entry_computation()->root_instruction()->shape(), + ParseShape("(s32[1], s32[2], s32[3])").ValueOrDie())); EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(_, op::Constant(), _)); } @@ -712,7 +709,7 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { // We should have added a new loop counter for s32[] to the end of the tuple. 
SCOPED_TRACE(m->ToString()); Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + ParseShape("(s32[], s32[], s32[], s32[])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 039ccda7322f5efda6a827efbeda1225c3596cc0..d77386497a14b3e52be2ea7f655fa330f60e4a97 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -97,7 +97,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { WhileUtil::MakeInstructionsLiveIn( HloInstruction* while_instr, absl::Span instructions) { - CHECK(ShapeUtil::IsTuple(while_instr->shape())); + CHECK(while_instr->shape().IsTuple()); int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); Shape new_while_shape = while_instr->shape(); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 5e6941933330fde29bc9c779aae4bb3c36914660..d92b9870f373564ae8fd904c8bf9f0d1afbff9c4 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -180,8 +180,8 @@ body { cond { param.c = (s32[], s32[]) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT condition = pred[] get-tuple-element(infeed), index=0 } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index 83d696fe0915086c3c98b6d7cbdaeaeb4d9d0bdb..661b7aa7d99ca549da6a509812760a1665d60919 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ 
b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -31,16 +31,21 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - if (instruction->HasSideEffect() || - !ShapeUtil::IsArray(instruction->shape()) || + if (instruction->HasSideEffect() || !instruction->shape().IsArray() || instruction->opcode() == HloOpcode::kConstant) { continue; } if (comp->IsRemovable(instruction) && ShapeUtil::IsZeroElementArray(instruction->shape())) { + // If the instruction doesn't have a layout, use a default layout for + // the literal. + Shape shape = instruction->shape(); + if (!LayoutUtil::HasLayout(shape)) { + LayoutUtil::SetToDefaultLayout(&shape); + } TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( - instruction, HloInstruction::CreateConstant( - Literal::CreateFromShape(instruction->shape())))); + instruction, + HloInstruction::CreateConstant(Literal::CreateFromShape(shape)))); changed = true; } } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index a546a6d39cc55d1f327b8449c7d26cd4c95dbf98..572a79609e7a912277af0fd2ba43f9a1e14a6f52 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -82,5 +82,18 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateConstant) { EXPECT_FALSE(changed); } +TEST_F(ZeroSizedHloEliminationTest, ZeroSizedInstructionWithoutLayoutFolded) { + Shape op_shape = ShapeUtil::MakeShape(F32, {4, 0}); + op_shape.clear_layout(); + HloInstruction* param1 = builder_.AddInstruction( + HloInstruction::CreateParameter(1, op_shape, "zero sized param 1")); + HloInstruction* param2 = builder_.AddInstruction( + HloInstruction::CreateParameter(2, op_shape, "zero sized param 2")); + 
builder_.AddInstruction( + HloInstruction::CreateBinary(op_shape, HloOpcode::kAdd, param1, param2)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_TRUE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 746ab9e9977b1b10cdb0cb57197027d65bd50f55..94854047e530babe2234381a615aeb805f0d5933 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -27,12 +27,31 @@ Shape::Shape(const ShapeProto& shape_proto) { for (const int64 dimension : shape_proto.dimensions()) { add_dimensions(dimension); } + // A malformed proto may have different is_dynamic_dimension_size and + // dimensions_size. Since C++ is evil, and we have no good way of bailing out + // in a constructor, conservatively trim the is_dynamic_dimension size. + // TODO(b/120111794): Make this a hard error when we have a factory method + // instead of a constructor. + if (shape_proto.dimensions_size() != + shape_proto.is_dynamic_dimension_size()) { + if (shape_proto.is_dynamic_dimension_size() != 0) { + LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " + "fields does not match number of dimension fields"; + } else { + LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty"; + } + } + int64 num_dynamic_dimension_fields = std::min( + shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size()); + for (int i = 0; i < num_dynamic_dimension_fields; i++) { + dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i); + } tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { *add_tuple_shapes() = Shape(element_shape); } if (shape_proto.has_layout()) { - *mutable_layout() = shape_proto.layout(); + *mutable_layout() = Layout::CreateFromProto(shape_proto.layout()); } } @@ -43,12 +62,15 @@ ShapeProto Shape::ToProto() const { for (const int64 dimension : 
dimensions()) { proto.add_dimensions(dimension); } + for (const bool dynamic : dynamic_dimensions_) { + proto.add_is_dynamic_dimension(dynamic); + } proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size()); for (const Shape& shape : tuple_shapes()) { *proto.add_tuple_shapes() = shape.ToProto(); } if (has_layout()) { - *proto.mutable_layout() = layout(); + *proto.mutable_layout() = layout().ToProto(); } return proto; } @@ -61,6 +83,101 @@ string Shape::ToString(bool print_layout) const { } } +bool Shape::is_static() const { + if (IsTuple()) { + for (const Shape& subshape : tuple_shapes_) { + if (!subshape.is_static()) { + return false; + } + } + } + return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); +} + +void Shape::DeleteDimension(int64 dim_to_delete) { + CHECK(IsArray()); + CHECK_GE(dim_to_delete, 0); + CHECK_LT(dim_to_delete, dimensions_.size()); + dimensions_.erase(dimensions_.begin() + dim_to_delete); + dynamic_dimensions_.erase(dynamic_dimensions_.begin() + dim_to_delete); + if (LayoutUtil::HasLayout(*this)) { + layout_.set_format(DENSE); + for (int64 i = 0; i < layout_.minor_to_major().size();) { + if (layout_.minor_to_major(i) == dim_to_delete) { + layout_.mutable_minor_to_major()->erase( + layout_.mutable_minor_to_major()->begin() + i); + continue; + } + if (layout_.minor_to_major(i) > dim_to_delete) { + (*layout_.mutable_minor_to_major())[i] -= 1; + } + ++i; + } + } +} + +bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { + if (lhs.IsTuple()) { + return rhs.IsTuple() && + absl::c_equal( + lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { return (*this)(l, r); }); + } else if (!lhs.IsArray()) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. 
+ return lhs.element_type() == rhs.element_type(); + } + + if (!rhs.IsArray()) { + return false; + } + + if (!ignore_element_type_) { + if ((ignore_fp_precision_ && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision_ && !ShapeUtil::SameElementType(lhs, rhs))) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + } + + if (!ignore_layout_) { + if (lhs.layout().format() != rhs.layout().format()) { + VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; + return false; + } + if (LayoutUtil::IsDenseArray(lhs)) { + Layout::Equal equal; + if (ignore_tiles_in_layout_) { + equal.IgnoreTiles(); + } + if (ignore_element_size_in_layout_) { + equal.IgnoreElementSize(); + } + if (!equal(lhs.layout(), rhs.layout())) { + VLOG(3) << "CompareShapes: lhs layout != rhs layout"; + return false; + } + } + } + + if (!ShapeUtil::SameDimensions(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; + return false; + } + + if (!ignore_dynamic_dimension_) { + for (int i = 0; i < lhs.rank(); ++i) { + if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { + VLOG(3) + << "CompareShapes: lhs and rhs have different dynamic dimensions."; + return false; + } + } + } + return true; +} + std::ostream& operator<<(std::ostream& out, const Shape& shape) { out << shape.ToString(/*print_layout=*/true); return out; diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 7f6b14ab4286c696dce64d2250a3fe8a57e4865b..78cea83c6d71e5965f10cd3a917ffccabd630462 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -20,6 +20,8 @@ limitations under the License. 
#include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -43,6 +45,43 @@ class Shape { // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". string ToString(bool print_layout = false) const; + // Returns the rank (number of dimensions) of the given shape. Shape must be + // an array. + int64 rank() const { + CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); + return dimensions_.size(); + } + + // Returns whether the shape is of the specified type (array, tuple, etc). + bool IsArray() const { return primitive_util::IsArrayType(element_type()); } + bool IsTuple() const { return element_type() == TUPLE; } + bool IsToken() const { return element_type() == TOKEN; } + bool IsOpaque() const { return element_type() == OPAQUE; } + + // Returns true if no array dimension in the shape is dynamically sized. Tuple + // shapes are traversed recursively. + bool is_static() const; + + // Returns true if the given dimension is dynamically-sized. + bool is_dynamic_dimension(int dimension) const { + return dynamic_dimensions_.at(dimension); + } + + // Sets whether or not the given dimension is dynamically-sized. + void set_dynamic_dimension(int dimension, bool is_dynamic) { + dynamic_dimensions_[dimension] = is_dynamic; + } + + const std::vector& dynamic_dimensions() const { + return dynamic_dimensions_; + } + + // Add dimension_upper_bound(). + + // Removes the given dimension form the shape. Layout, if it exists, is + // adjusted to match the modified shape. + void DeleteDimension(int64 dim_to_delete); + // The following methods mirror the protobuf generated code interface for the // message ShapeProto. This enabled easy migration of this data structure // from a proto to a proper C++ class. 
@@ -57,10 +96,16 @@ class Shape { int dimensions_size() const { return dimensions_.size(); } int64 dimensions(int index) const { return dimensions_.at(index); } void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } - void add_dimensions(int64 value) { dimensions_.push_back(value); } - void clear_dimensions() { dimensions_.clear(); } + void add_dimensions(int64 value) { + dimensions_.push_back(value); + dynamic_dimensions_.push_back(false); + } + void clear_dimensions() { + dimensions_.clear(); + dynamic_dimensions_.clear(); + } const std::vector& dimensions() const { return dimensions_; } - std::vector* mutable_dimensions() { return &dimensions_; } + absl::Span mutable_dimensions() { return absl::MakeSpan(dimensions_); } // Methods for accessing the tuple subshapes. This field only non-empty for // tuple shapes. @@ -76,21 +121,10 @@ class Shape { std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } // Methods for accessing the layout field. - bool has_layout() const { return layout_.has_value(); } - const Layout& layout() const { - if (layout_.has_value()) { - return *layout_; - } else { - return Layout::default_instance(); - } - } - Layout* mutable_layout() { - if (!layout_.has_value()) { - layout_ = Layout(); - } - return &layout_.value(); - } - void clear_layout() { layout_.reset(); } + bool has_layout() const { return layout_.format() != INVALID_FORMAT; } + const Layout& layout() const { return layout_; } + Layout* mutable_layout() { return &layout_; } + void clear_layout() { layout_.Clear(); } void Swap(Shape* other) { using std::swap; @@ -101,25 +135,84 @@ class Shape { element_type_ = PRIMITIVE_TYPE_INVALID; dimensions_.clear(); tuple_shapes_.clear(); - layout_.reset(); + clear_layout(); } string SerializeAsString() const { return ToProto().SerializeAsString(); } string ShortDebugString() const { return ToProto().ShortDebugString(); } string DebugString() const { return ToProto().DebugString(); } - public: + // Equal is a 
configurable functor to check the equality of two shapes. + // + // Examples: + // + // - Comparing two shapes ignoring their layout difference: + // Equal().IgnoreLayout()(shape1, shape2); + // + // - Comparing two shapes ignoring their layout and element type difference: + // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); + class Equal { + public: + Equal() = default; + + bool operator()(const Shape& lhs, const Shape& rhs); + + Equal& IgnoreLayout() { + ignore_layout_ = true; + return *this; + } + Equal& IgnoreTilesInLayout() { + ignore_tiles_in_layout_ = true; + return *this; + } + Equal& IgnoreElementSizeInLayout() { + ignore_element_size_in_layout_ = true; + return *this; + } + Equal& IgnoreElementType() { + ignore_element_type_ = true; + return *this; + } + Equal& IgnoreFpPrecision() { + ignore_fp_precision_ = true; + return *this; + } + Equal& IgnoreDynamicDimension() { + ignore_dynamic_dimension_ = true; + return *this; + } + + private: + bool ignore_layout_ = false; + bool ignore_tiles_in_layout_ = false; + bool ignore_element_size_in_layout_ = false; + bool ignore_element_type_ = false; + bool ignore_fp_precision_ = false; + bool ignore_dynamic_dimension_ = false; + }; + + // Test that all fields of the shape are the same, equivalent to Equal(). + bool operator==(const Shape& other) const { return Equal()(*this, other); } + bool operator!=(const Shape& other) const { return !(*this == other); } + + private: // The element type of this shape (tuple, array, etc). PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; - // The array bounds of the dimensions. This is nonempty only for array shapes. + // The array bounds of the dimensions. This is nonempty only for array + // shapes. For a dynamically-sized dimension, the respective value in this + // vector is an inclusive upper limit of the array bound. 
std::vector dimensions_; + // This vector is the same size as 'dimensions_' and indicates whether the + // respective dimension is dynamically sized. + std::vector dynamic_dimensions_; + // The tuple element subshapes. This is nonempty only for tuple shapes. std::vector tuple_shapes_; - // The array layout of the shape. This is present only for array shapes. - absl::optional layout_; + // The layout of the shape. Only relevant for arrays. + Layout layout_; }; // Shape of the parameters and output of an XLA computation. This is analogous diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index d44db89d571891ecef554cd45c050017833982bb..a000886d60d06a4a598910c901accb6dfd0a8f1a 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -52,7 +52,7 @@ bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const { const Layout& ShapeLayout::layout() const { CHECK(LayoutIsSet()); - CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!shape_.IsTuple()); return shape_.layout(); } @@ -61,15 +61,15 @@ void ShapeLayout::Clear() { LayoutUtil::ClearLayout(&shape_); } bool ShapeLayout::LayoutIsSet() const { return LayoutUtil::HasLayout(shape_); } void ShapeLayout::ResetLayout(const Layout& layout) { - CHECK(!ShapeUtil::IsTuple(shape_)); - CHECK(!ShapeUtil::IsOpaque(shape_)); + CHECK(!shape_.IsTuple()); + CHECK(!shape_.IsOpaque()); *shape_.mutable_layout() = layout; TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); } void ShapeLayout::ResetLayout(const Layout& layout, ShapeIndexView shape_index) { - CHECK(ShapeUtil::IsTuple(shape_)); + CHECK(shape_.IsTuple()); *ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() = layout; TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc index e396897eeebc2e7bdc2dc49300c8906710608b05..526abafea5cc244418a4ec05db7da6203716b483 100644 --- 
a/tensorflow/compiler/xla/shape_test.cc +++ b/tensorflow/compiler/xla/shape_test.cc @@ -41,11 +41,13 @@ class ShapeTest : public ::testing::Test { ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_}); const Shape nested_tuple_ = ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); + const Shape dyanmic_matrix_ = + ShapeUtil::MakeShape(S32, {5, 2}, {true, false}); }; TEST_F(ShapeTest, ShapeToFromProto) { - for (const Shape& shape : - {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}) { + for (const Shape& shape : {opaque_, token_, scalar_, matrix_, matrix2_, + tuple_, nested_tuple_, dyanmic_matrix_}) { Shape shape_copy(shape.ToProto()); EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) << shape << " != " << shape_copy; @@ -74,6 +76,65 @@ TEST_F(ShapeTest, ShapeToString) { nested_tuple_.ToString(/*print_layout=*/true)); } +TEST_F(ShapeTest, DynamicShapeToString) { + Shape array_shape = + ShapeUtil::MakeShape(F32, {23, 44, 55}, {true, false, true}); + EXPECT_EQ("f32[<=23,44,<=55]", array_shape.ToString()); + + array_shape.set_dynamic_dimension(2, false); + EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString()); +} + +TEST_F(ShapeTest, EqualityTest) { + // Different layouts. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {0, 1})); + + // Different dims. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {44, 23}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); + + // Different elements. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(S32, {44, 23}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); + + // Equal shapes. 
+ EXPECT_EQ(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); +} + +TEST_F(ShapeTest, IsStatic) { + EXPECT_TRUE(opaque_.is_static()); + EXPECT_TRUE(token_.is_static()); + EXPECT_TRUE(matrix_.is_static()); + EXPECT_TRUE(tuple_.is_static()); + EXPECT_TRUE(nested_tuple_.is_static()); + + Shape dynamic_matrix = matrix_; + EXPECT_TRUE(dynamic_matrix.is_static()); + dynamic_matrix.set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_matrix.is_static()); + + Shape dynamic_tuple = tuple_; + EXPECT_TRUE(dynamic_tuple.is_static()); + ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_tuple.is_static()); +} + +TEST_F(ShapeTest, IsDynamicDimension) { + Shape dynamic_matrix = matrix_; + dynamic_matrix.set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_matrix.is_dynamic_dimension(0)); + EXPECT_TRUE(dynamic_matrix.is_dynamic_dimension(1)); + + Shape dynamic_tuple = tuple_; + EXPECT_TRUE(dynamic_tuple.is_static()); + ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_tuple.is_static()); +} + TEST_F(ShapeTest, ProgramShapeToFromProto) { ProgramShape program_shape; *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 7bf97729165bef98fabc29040e02203eee68a53c..089120179e2a77518eb5b18c11a35670b03e9b77 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -395,7 +395,7 @@ class ShapeTreeIterator template int64 ShapeTree::CountSubshapes(const Shape& shape) { int64 current_count = 1; - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { int64 count = ShapeUtil::TupleElementCount(shape); for (int i = 0; i < count; ++i) { current_count += CountSubshapes(shape.tuple_shapes(i)); @@ -407,7 +407,7 @@ int64 ShapeTree::CountSubshapes(const Shape& shape) 
{ template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node, Index* index) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 size = ShapeUtil::TupleElementCount(shape); #ifndef NDEBUG index->children_count = size; @@ -443,7 +443,7 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, template void ShapeTree::InitChildren(const Shape& shape, Node* node, Index* index) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 size = ShapeUtil::TupleElementCount(shape); #ifndef NDEBUG index->children_count = size; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index a4d4e1e53e727bdf7822cacaa4559fcae59d4eae..d045fc7a9e291258640eca75166e116cf7390a7b 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -81,78 +82,16 @@ bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { /* static */ bool ShapeUtil::IsArrayPrimitiveType( PrimitiveType primitive_type) { - return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && - primitive_type != OPAQUE && primitive_type != TOKEN; + return primitive_util::IsArrayType(primitive_type); } namespace { - -// Recursive helper for comparing the equality of two shapes. Returns true if -// the shapes are the same. If compare_layouts is true, then layouts must also -// match. 
-bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, - bool ignore_fp_precision) { - if ((ignore_fp_precision && - !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || - (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } - - if (ShapeUtil::IsTuple(lhs)) { - return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts, - ignore_fp_precision); - }); - } else if (!ShapeUtil::IsArray(lhs)) { - // Non-tuple, non-array tupes such as opaque and token types are trivially - // the same. - return true; - } - - if (compare_layouts) { - if (lhs.layout().format() != rhs.layout().format()) { - return false; - } - if (LayoutUtil::IsDenseArray(lhs)) { - if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), - LayoutUtil::MinorToMajor(rhs))) { - VLOG(3) << "CompareShapes: lhs layout != rhs layout"; - return false; - } - - const auto& lhs_tiles = lhs.layout().tiles(); - const auto& rhs_tiles = rhs.layout().tiles(); - if (lhs_tiles.size() != rhs_tiles.size()) { - return false; - } - for (int64 i = 0; i < lhs_tiles.size(); i++) { - if (!absl::c_equal(lhs_tiles[i].dimensions(), - rhs_tiles[i].dimensions())) { - return false; - } - } - - if (lhs.layout().element_size_in_bits() != - rhs.layout().element_size_in_bits()) { - return false; - } - } - } - - if (!ShapeUtil::SameDimensions(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; - return false; - } - return true; -} - // Constructs and returns the new shape with the given minor_to_major order in // its Layout. 
StatusOr MakeShapeWithLayoutInternal( PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major) { + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { if (dimensions.size() != minor_to_major.size()) { return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); @@ -163,23 +102,19 @@ StatusOr MakeShapeWithLayoutInternal( } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); - for (int64 value : minor_to_major) { - min2maj->Add(value); - } + *shape.mutable_layout() = + LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits); if (!shape.has_layout()) { return InvalidArgument("Shape has no layout."); } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); return shape; } - } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, - /*ignore_fp_precision=*/false); + bool equal = Shape::Equal()(lhs, rhs); + if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -190,8 +125,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, - /*ignore_fp_precision=*/true); + bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -200,12 +134,6 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } -/* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(ShapeUtil::IsArray(shape)) - << "Non-arrays do not have a rank, shape: " << shape; - return 
shape.dimensions_size(); -} - /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -232,18 +160,38 @@ StatusOr MakeShapeWithLayoutInternal( return MakeValidatedShape(element_type, dimensions).ValueOrDie(); } +/* static */ Shape ShapeUtil::MakeShape( + PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions) { + return MakeValidatedShape(element_type, dimensions, dynamic_dimensions) + .ValueOrDie(); +} + /* static */ StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions) { - CHECK(IsArrayPrimitiveType(element_type)); + CHECK(IsArrayPrimitiveType(element_type)) << element_type; Shape result; TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result)); return result; } +/* static */ StatusOr ShapeUtil::MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions) { + TF_ASSIGN_OR_RETURN(Shape shape, + MakeValidatedShape(element_type, dimensions)); + for (int i = 0; i < dynamic_dimensions.size(); ++i) { + shape.set_dynamic_dimension(i, dynamic_dimensions[i]); + } + return shape; +} + /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major) { - return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { + return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major, + tiles, element_size_in_bits) .ValueOrDie(); } @@ -319,7 +267,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { CHECK(LayoutUtil::IsDenseArray(*shape)); - shape->mutable_layout()->add_minor_to_major(Rank(*shape)); + shape->mutable_layout()->add_minor_to_major(shape->rank()); shape->add_dimensions(bound); 
TF_DCHECK_OK(ValidateShape(*shape)); } @@ -334,7 +282,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (!IsArray(shape)) { + if (!shape.IsArray()) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -358,6 +306,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case U32: case U64: case C64: + case C128: case TUPLE: case OPAQUE: case TOKEN: @@ -376,27 +325,24 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return primitive_util::IsFloatingPointType(shape.element_type()); } -/* static */ bool ShapeUtil::IsArray(const Shape& shape) { - return IsArrayPrimitiveType(shape.element_type()); -} - /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { - return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), - shape.tuple_shapes().end(), IsTuple); + return shape.IsTuple() && + absl::c_any_of(shape.tuple_shapes(), + [](const Shape& s) { return s.IsTuple(); }); } /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) { - return IsTuple(shape) && TupleElementCount(shape) == 0; + return shape.IsTuple() && TupleElementCount(shape) == 0; } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { - CHECK(IsTuple(shape)) << HumanString(shape); + CHECK(shape.IsTuple()) << HumanString(shape); return shape.tuple_shapes_size(); } /* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape, int64 index) { - CHECK(IsTuple(shape)); + CHECK(shape.IsTuple()); CHECK_GT(TupleElementCount(shape), index); TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index))); return shape.tuple_shapes(index); @@ -412,7 +358,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, int64 limit) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); - CHECK(IsTuple(tuple)); + 
CHECK(tuple.IsTuple()); CHECK_LE(start, TupleElementCount(tuple)); CHECK_LE(limit, TupleElementCount(tuple)); @@ -429,15 +375,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( complex_shape.element_type())); } -/* static */ bool ShapeUtil::ShapeIs(const Shape& shape, - PrimitiveType element_type, - std::initializer_list dimensions) { - return Equal(shape, MakeShape(element_type, dimensions)); -} - /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); - DCHECK_EQ(shape.dimensions_size(), Rank(shape)); + DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), shape.rank()); if (shape.dimensions().size() == 1) { return shape.dimensions()[0]; } @@ -447,8 +387,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) { - CHECK(IsArray(shape) || IsTuple(shape)); - if (IsArray(shape)) { + CHECK(shape.IsArray() || shape.IsTuple()); + if (shape.IsArray()) { return ElementsIn(shape); } int64 count = 0; @@ -472,7 +412,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; + return shape.IsArray() && ElementsIn(shape) == 0; } /* static */ bool ShapeUtil::IsScalarWithElementType( @@ -480,56 +420,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsScalar(shape) && shape.element_type() == element_type; } -namespace { - -// Class to memoize the computation of -// absl::AsciiStrToLower(PrimitiveType_Name(p)) -// for all PrimitiveType values "p" -class PrimitiveTypeNameGenerator { - public: - PrimitiveTypeNameGenerator() { - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = absl::AsciiStrToLower( - PrimitiveType_Name(static_cast(i))); - } - } - } - 
const string& LowercaseName(PrimitiveType t) { - return lowercase_name_[static_cast(t)]; - } - - private: - string lowercase_name_[PrimitiveType_ARRAYSIZE]; -}; - -const string& LowercasePrimitiveTypeName(PrimitiveType s) { - static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); - return gen->LowercaseName(s); -} - -StatusOr StringToPrimitiveType(const string& name) { - static std::unordered_map* name_to_type = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - auto value = static_cast(i); - (*map)[LowercasePrimitiveTypeName(value)] = value; - } - } - return map; - }(); - auto found = name_to_type->find(name); - if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", name); - } - return found->second; -} - -} // namespace - /* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { + if (shape.IsTuple()) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -539,12 +431,21 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - absl::StrJoin(shape.dimensions(), ","), "]"); + std::vector dim_elements; + for (int i = 0; i < shape.dimensions_size(); ++i) { + if (shape.is_dynamic_dimension(i)) { + dim_elements.push_back(StrCat("<=", shape.dimensions(i))); + } else { + dim_elements.push_back(StrCat(shape.dimensions(i))); + } + } + return StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[", + absl::StrJoin(dim_elements, ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { - if (IsTuple(shape)) { + if (shape.IsTuple()) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -554,12 +455,14 @@ StatusOr StringToPrimitiveType(const 
string& name) { text += ")"; return text; } - string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + string result = StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "["); for (int i = 0; i < shape.dimensions().size(); i++) { - StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); + StrAppend(&result, (i > 0) ? "," : "", + shape.is_dynamic_dimension(i) ? "<=" : "", shape.dimensions(i)); } result += "]"; - if (!IsScalar(shape) && IsArray(shape)) { + if (!IsScalar(shape) && shape.IsArray()) { if (LayoutUtil::HasLayout(shape)) { StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } @@ -580,155 +483,25 @@ StatusOr StringToPrimitiveType(const string& name) { HumanString(program_shape.result())); } -namespace { -// Parses shapes with simple recursive descent structure -- consumes from the -// front of s and passes that view recursively as required. -StatusOr ParseShapeStringInternal(absl::string_view* s) { - *s = absl::StripLeadingAsciiWhitespace(*s); - - if (absl::ConsumePrefix(s, "(")) { // Tuple. - std::vector shapes; - bool must_end = false; - while (true) { - if (absl::ConsumePrefix(s, ")")) { - break; - } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); - } - shapes.emplace_back(); - TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - *s = absl::StripLeadingAsciiWhitespace(*s); - must_end = !absl::ConsumePrefix(s, ","); - } - return ShapeUtil::MakeTupleShape(shapes); - } - - string element_type_string; - string dimensions_string; - string format_string; - string layout_string; - // absl::string_view is not compatible with internal RE2 StringPiece, so - // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our string_view type. 
- static LazyRE2 shape_pattern = { - "^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})" - "?"}; - tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, - &dimensions_string, &format_string, &layout_string)) { - size_t consumed = s->size() - s_consumable.size(); - s->remove_prefix(consumed); - auto string_to_int64 = [&s](absl::string_view input) -> StatusOr { - int64 element; - if (!absl::SimpleAtoi(input, &element)) { - return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, - *s); - } - return element; - }; - - auto comma_list_to_int64s = - [string_to_int64](const string& input) -> StatusOr> { - std::vector results; - for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { - TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); - results.push_back(element); - } - return results; - }; - - // Extract the dimensions. - TF_ASSIGN_OR_RETURN(std::vector dimensions, - comma_list_to_int64s(dimensions_string)); - - // Extract the primitive element type. - TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, - StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { - return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string); - } - - Shape result; - if (primitive_type == OPAQUE) { - result = ShapeUtil::MakeOpaqueShape(); - } else if (primitive_type == TOKEN) { - result = ShapeUtil::MakeTokenShape(); - } else if (format_string.empty() && layout_string.empty()) { - // Create a shape without a layout set. 
- TF_ASSIGN_OR_RETURN( - result, ShapeUtil::MakeValidatedShape(primitive_type, dimensions)); - } else if (format_string == "sparse") { - TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); - result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, - max_elements); - } else if (format_string.empty() || format_string == "dense") { - // Extract the layout minor-to-major and set it. - TF_ASSIGN_OR_RETURN(std::vector min2maj, - comma_list_to_int64s(layout_string)); - TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( - primitive_type, dimensions, min2maj)); - } else { - // This should not be reached. - LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" - << format_string << "\", layout: \"" << layout_string << "\""; - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); - return std::move(result); - } - - return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); -} -} // namespace - -/* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { - TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); - if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", s); - } - return shape; -} - /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { - CHECK(ShapeUtil::IsArray(lhs)); - CHECK(ShapeUtil::IsArray(rhs)); + CHECK(lhs.IsArray()); + CHECK(rhs.IsArray()); return absl::c_equal(lhs.dimensions(), rhs.dimensions()); } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return CompareShapes(lhs, rhs, /*compare_layouts=*/false, - /*ignore_fp_precision=*/false); + return Shape::Equal().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return IsArray(rhs) && SameDimensions(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - absl::c_equal(lhs.tuple_shapes(), 
rhs.tuple_shapes(), - CompatibleIgnoringElementType); - } else { - // Opaque, token, etc types are vacuously compatible. - return lhs.element_type() == rhs.element_type(); - } + return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && - CompatibleIgnoringElementType(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringFpPrecision); - } else { - // Opaque, token, etc types are vacuously compatible. - return lhs.element_type() == rhs.element_type(); - } + return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs); } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -739,7 +512,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape, int64 dimension_number) { if (dimension_number < 0) { - dimension_number += Rank(shape); + dimension_number += shape.rank(); } CHECK_GE(dimension_number, 0); return dimension_number; @@ -776,6 +549,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return sizeof(double); case C64: return sizeof(complex64); + case C128: + return sizeof(complex128); case TOKEN: // Tokens require no space. 
return 0; @@ -793,7 +568,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { TF_DCHECK_OK(ValidateShape(shape)); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); - } else if (IsArray(shape)) { + } else if (shape.IsArray()) { int64 byte_size = ByteSizeOfElements(shape); if (LayoutUtil::IsSparseArray(shape)) { byte_size += ByteSizeOfSparseIndices(shape); @@ -819,7 +594,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -835,8 +610,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); CHECK(LayoutUtil::IsSparseArray(shape)); - return LayoutUtil::MaxSparseElements(shape.layout()) * - ShapeUtil::Rank(shape) * sizeof(int64); + return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() * + sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( @@ -867,22 +642,22 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } return Status::OK(); } - if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) { + if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) { return 
InvalidArgument("sparse arrays must have rank > 0"); } - for (int64 i = 0; i < Rank(shape); ++i) { + for (int64 i = 0; i < shape.rank(); ++i) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( @@ -898,7 +673,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); - if (!IsArray(shape)) { + if (!shape.IsArray()) { return Status::OK(); } @@ -919,7 +694,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return sparse_elements_size; } int64 sparse_indices_size = - MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape)); + MultiplyWithoutOverflow(max_sparse_elements, shape.rank()); if (sparse_indices_size < 0) { return sparse_indices_size; } @@ -991,7 +766,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { ShapeIndexView index) { const Shape* subshape = &shape; for (auto i : index) { - if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size() || i < 0) { + if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) { return false; } subshape = &subshape->tuple_shapes(i); @@ -1003,7 +778,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)) + CHECK(return_shape->IsTuple()) << "Invalid index " << index << " for shape " << shape; return_shape = &return_shape->tuple_shapes(i); } @@ -1014,7 +789,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { const Shape& shape, ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - if (!IsTuple(*return_shape) || i < 0 || + if (!return_shape->IsTuple() || i < 0 || i >= return_shape->tuple_shapes_size()) { return InvalidArgument( "Shape index %s not a valid subshape index for tuple with shape %s", @@ -1029,7 +804,7 @@ StatusOr 
ParseShapeStringInternal(absl::string_view* s) { ShapeIndexView index) { Shape* return_shape = shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)); + CHECK(return_shape->IsTuple()); return_shape = return_shape->mutable_tuple_shapes(i); } return return_shape; @@ -1037,11 +812,11 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { - return !IsTuple(GetSubshape(shape, index)); + return !GetSubshape(shape, index).IsTuple(); } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { - if (!IsTuple(shape)) { + if (!shape.IsTuple()) { return 1; } int64 count = 0; @@ -1063,10 +838,15 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); return absl::c_linear_search(shape.dimensions(), 1); } +/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) { + return FilterDimensions( + [&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape); +} + namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in @@ -1075,7 +855,7 @@ Status ForEachSubshapeHelper(const Shape& shape, const ShapeUtil::StatusVisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { index->push_back(i); TF_RETURN_IF_ERROR(ForEachSubshapeHelper( @@ -1092,7 +872,7 @@ Status ForEachMutableSubshapeHelper( Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); - if (ShapeUtil::IsTuple(*shape)) { + if (shape->IsTuple()) { for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) { index->push_back(i); TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper( 
@@ -1150,6 +930,10 @@ Status ForEachMutableSubshapeHelper( for (auto dim : Permute(permutation, shape.dimensions())) { new_shape.add_dimensions(dim); } + for (int64 i = 0; i < shape.rank(); i++) { + new_shape.set_dynamic_dimension(permutation[i], + shape.is_dynamic_dimension(i)); + } // If `shape` has a layout, by contract we choose a new layout such that the // transpose defined by this permutation is a bitcast. @@ -1200,8 +984,8 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { - CHECK(IsArray(shape_pre)); - CHECK(IsArray(shape_post)); + CHECK(shape_pre.IsArray()); + CHECK(shape_post.IsArray()); auto nil = std::make_tuple(false, std::vector(), std::vector()); @@ -1248,7 +1032,7 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, auto unmodified_dim_pair = i < unmodified_dims.size() ? unmodified_dims[i] - : std::make_pair(Rank(shape_pre), Rank(shape_post)); + : std::make_pair(shape_pre.rank(), shape_post.rank()); if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { return nil; } @@ -1260,8 +1044,8 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); // Unmodified dimensions are merely common factors of rank 1. 
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), @@ -1311,8 +1095,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); CHECK(LayoutUtil::HasLayout(input_shape)); CHECK(LayoutUtil::HasLayout(output_shape)); @@ -1440,12 +1224,12 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); - for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { + for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) { if (input_shape.dimensions(input_dim) <= 1) { continue; } - std::vector input_unit_index(Rank(input_shape), 0); + std::vector input_unit_index(input_shape.rank(), 0); input_unit_index[input_dim] = 1; int64 logical_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, @@ -1471,11 +1255,48 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ absl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); + // Removing trivial dimensions from the shape simplifies the alignment + // algorithm since ones can go in any position. 
+ if (HasDegenerateDimensions(input_shape) || + HasDegenerateDimensions(output_shape)) { + auto simple_output_shape = + AlignLayouts(DropDegenerateDimensions(input_shape), + DropDegenerateDimensions(output_shape)); + if (!simple_output_shape) { + return absl::nullopt; + } + + auto layout = simple_output_shape->layout().minor_to_major(); + // For each one sized dimension in the output, increment the dimension + // numbers in layout that are more minor than the one. + absl::InlinedVector dim_map; + dim_map.reserve(simple_output_shape->rank()); + for (int64 i = 0; i < output_shape.rank(); ++i) { + if (output_shape.dimensions(i) != 1) { + dim_map.push_back(i); + } + } + for (int64& d : layout) { + d = dim_map[d]; + } - int64 input_rank = Rank(input_shape); - int64 output_rank = Rank(output_shape); + // Add the ones in descending order to the layout. Descending layouts tend + // to reduce the number of copies inserted in layout assignment. + for (int64 i = output_shape.rank() - 1; i >= 0; --i) { + if (output_shape.dimensions(i) == 1) { + layout.push_back(i); + } + } + Shape output_shape_with_layout = output_shape; + *output_shape_with_layout.mutable_layout()->mutable_minor_to_major() = + layout; + return output_shape_with_layout; + } + + int64 input_rank = input_shape.rank(); + int64 output_rank = output_shape.rank(); // First, calculate an alignment of the dimensions. A consecutive sequence of // input dimensions and output dimensions belong to the same alignment part if @@ -1521,10 +1342,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (input_dimension_product != output_dimension_product) { return absl::nullopt; } + // We also need to store an end element so that we know where the last // alignment part ends. alignment.push_back({input_rank, output_rank}); - // Now check if the physical layout can potentially be aligned to the output // shape by changing the physical layout of the output shape. 
We need to check // that all dimension numbers that belong to the same alignment part appear @@ -1536,40 +1357,23 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, for (int64 i = 0; i < input_rank;) { int64 current_dimension_number = input_dimension_numbers[i]; - // Skip trivial dimensions with a bound of 1. - if (input_shape.dimensions(current_dimension_number) == 1) { - ++i; - continue; - } - - // Calculate the number of non-trivial dimension bounds in the input shape - // belonging to the current alignment part. + // Trivial dimensions are stripped. + CHECK_NE(input_shape.dimensions(current_dimension_number), 1); const int64 current_alignment_index = dimension_to_alignment_index[current_dimension_number]; // Because of the special end element that we added, we can be sure that // 'current_alignment_index' is < alignment.size() - 1. CHECK_LT(current_alignment_index, alignment.size() - 1); - int64 num_non_trivial_dimensions_in_alignment_part = 0; - for (int64 j = alignment[current_alignment_index].first; - j < alignment[current_alignment_index + 1].first; ++j) { - if (input_shape.dimensions(j) != 1) { - ++num_non_trivial_dimensions_in_alignment_part; - } - } // Check that the following 'num_non_trivial_dimensions_in_alignment_part' // dimension numbers (ignoring dimension numbers with dimension bound 1) are // in descending order and belong to the current alignment part. - for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; + for (int64 j = 0; j < alignment[current_alignment_index + 1].first - + alignment[current_alignment_index].first; ++i, ++j) { if (i == input_rank) { return absl::nullopt; } - // Skip trivial dimensions with a bound of 1. - if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { - --j; - continue; - } // If the current dimension number belongs to a different alignment part, // or the dimension numbers are not in descending order, we can return // early. 
@@ -1580,22 +1384,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } current_dimension_number = input_dimension_numbers[i]; } - // The output dimension numbers that belong to the current alignment part - // need to appear in the same descending order as in the input. Again, we - // can skip dimensions with a bound of 1. + // need to appear in the same descending order as in the input. for (int64 j = alignment[current_alignment_index + 1].second - 1; j >= alignment[current_alignment_index].second; --j) { - if (output_shape.dimensions(j) != 1) { - output_layout.push_back(j); - } - } - } - // Now add all the dimensions with dimension bound 1 at the end of - // 'output_layout'. - for (int64 i = 0; i < output_rank; ++i) { - if (output_shape.dimensions(i) == 1) { - output_layout.push_back(i); + output_layout.push_back(j); } } CHECK_EQ(output_layout.size(), output_rank); @@ -1612,30 +1405,14 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { - CHECK(IsArray(shape)); - shape.mutable_dimensions()->erase(shape.mutable_dimensions()->begin() + - dim_to_delete); - if (LayoutUtil::HasLayout(shape)) { - Layout* layout = shape.mutable_layout(); - layout->set_format(DENSE); - for (size_t i = 0; i < layout->minor_to_major().size();) { - if (layout->minor_to_major(i) == dim_to_delete) { - layout->mutable_minor_to_major()->erase( - layout->minor_to_major().begin() + i); - continue; - } - if (layout->minor_to_major(i) > dim_to_delete) { - (*layout->mutable_minor_to_major())[i] -= 1; - } - ++i; - } - } + CHECK(shape.IsArray()); + shape.DeleteDimension(dim_to_delete); return shape; } /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { - CHECK(IsArray(shape)); + CHECK(shape.IsArray()); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { @@ -1655,8 +1432,11 @@ 
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, size_t hash_value = hash()(shape.element_type()); if (shape.tuple_shapes().empty()) { - for (int64 dim : shape.dimensions()) { - hash_value = Hash64Combine(hash_value, hash()(dim)); + for (int i = 0; i < shape.dimensions_size(); ++i) { + hash_value = + Hash64Combine(hash_value, hash()(shape.dimensions(i))); + hash_value = Hash64Combine(hash_value, + hash()(shape.is_dynamic_dimension(i))); } hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout())); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 84a27f662a57ba274562e2e9be57b7e971c9b477..7f610a6085d6fbe3d3143d5027cdc43d4b07bcbf 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -185,7 +185,7 @@ class ShapeUtil { // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: IsArray(shape) + // Precondition: shape.IsArray() static int64 ElementsIn(const Shape& shape); // As ElementsIn(), but recurses through tuples. @@ -207,7 +207,7 @@ class ShapeUtil { // Returns the number of bytes used to store the primitive_type. // - // Precondition: ShapeUtil::IsArray(shape) + // Precondition: shape.IsArray() static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -241,10 +241,6 @@ class ShapeUtil { // (param_name: f32[42x12], ...) -> f32[24x42] static string HumanString(const ProgramShape& program_shape); - // Parses a ShapeUtil::HumanString-format shape string back into a shape - // object. - static StatusOr ParseShapeString(absl::string_view s); - // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. 
// Precondition: IsArray(lhs) && IsArray(rhs) @@ -266,7 +262,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that that they have the same element type + // point types; otherwise, checks that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -294,16 +290,12 @@ class ShapeUtil { // being F32. Tuple elements are compared recursively for compatibility. static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); - // Returns whether the lhs and rhs shapes are identical protobufs. + // Returns whether the lhs and rhs shapes are identical. static bool Equal(const Shape& lhs, const Shape& rhs); // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); - // Returns the rank (number of dimensions) of the given shape. - // Precondition: !IsTuple(shape) - static int64 Rank(const Shape& shape); - // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just // fluff. Note that zero dimensions are included in the true rank, e.g., @@ -317,10 +309,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return IsArray(shape) && Rank(shape) == 0; + return shape.IsArray() && shape.rank() == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return IsArray(shape) && TrueRank(shape) == 0; + return shape.IsArray() && TrueRank(shape) == 0; } // Returns whether "shape" is a scalar (array) with the given element_type. @@ -375,11 +367,24 @@ class ShapeUtil { static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions); + // Constructs a new shape with the given element type and sequence of + // potentially dynamic dimensions. 
The argument 'dynamic_dimensions' indicates + // with a true value that the respective dimension is dynamic. If the + // dimension is dynamic then the respective value in 'dimension' is an upper + // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be + // the same size. + static Shape MakeShape(PrimitiveType element_type, + absl::Span dimensions, + const std::vector& dynamic_dimensions); + // Constructs a new shape with the given element type and sequence of // dimensions. Method checks if the element type is valid and the shape's // size fits in std::numeric_limits::max(). static StatusOr MakeValidatedShape(PrimitiveType element_type, absl::Span dimensions); + static StatusOr MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions); // Creates a Shape with element type corresponding to T and the given // dimensions @@ -393,7 +398,9 @@ class ShapeUtil { // Returns a value shape such that shape.has_layout(). static Shape MakeShapeWithLayout(PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major); + absl::Span minor_to_major, + absl::Span tiles = {}, + int64 element_size_in_bits = 0); static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, absl::Span dimensions, @@ -447,27 +454,6 @@ class ShapeUtil { // that floating point numbers are signed. static bool ElementIsSigned(const Shape& shape); - // Returns whether the shape is a tuple. - static bool IsTuple(const Shape& shape) { - return shape.element_type() == TUPLE; - } - - // Returns whether the shape is an opaque value (i.e. an 'existential' typed - // value that is passed to CustomCall operations). - static bool IsOpaque(const Shape& shape) { - return shape.element_type() == OPAQUE; - } - - // Returns whether the shape is an token value used for ordering - // side-effecting operations. 
- static bool IsToken(const Shape& shape) { - return shape.element_type() == TOKEN; - } - - // Returns whether the shape is an array. Note that scalars are considered - // arrays. - static bool IsArray(const Shape& shape); - // Returns whether the given primitive type corresponds to an array shape. static bool IsArrayPrimitiveType(PrimitiveType primitive_type); @@ -497,12 +483,6 @@ class ShapeUtil { // shape. static Shape ComplexComponentShape(const Shape& complex_shape); - // Shorthand for testing whether a shape is of a given element type and - // sequence of dimensions. - ABSL_DEPRECATED("Use Equal() instead.") - static bool ShapeIs(const Shape& shape, PrimitiveType element_type, - std::initializer_list dimensions); - // Returns true if the given shape has a subshape at the given index. static bool IndexIsValid(const Shape& shape, ShapeIndexView index); @@ -551,6 +531,9 @@ class ShapeUtil { // (dimensions with bound 1). static bool HasDegenerateDimensions(const Shape& shape); + // Drops any degenerate dimensions (i.e. dimensions of size 1) + static Shape DropDegenerateDimensions(const Shape& shape); + // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i]. // @@ -694,11 +677,9 @@ class ShapeUtil { template static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { - ForEachIndexWithStatus(shape, - [&](absl::Span indices) { - return StatusOr(visitor_function(indices)); - }) - .IgnoreError(); + ForEachIndexWithStatus(shape, [&](absl::Span indices) { + return StatusOr(visitor_function(indices)); + }).IgnoreError(); } // A parallel version of ForEachIndex(WithStatus). 
This can only be used if @@ -747,7 +728,7 @@ class ShapeUtil { if (ShapeUtil::IsZeroElementArray(shape)) { return Status::OK(); } - CHECK_EQ(Rank(shape), base.size()); + CHECK_EQ(shape.rank(), base.size()); CHECK_EQ(incr.size(), base.size()); CHECK_EQ(count.size(), base.size()); const int64 rank = LayoutUtil::MinorToMajor(shape).size(); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 60bdbe302045e6f3b4bae500c50bc68fb217525d..020b062f6b1b032bab958772d3a6a1e35daee38b 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -82,102 +82,6 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) { ASSERT_EQ(3, shape.dimensions(0)); } -TEST(ShapeUtilTest, ParseShapeStringR2F32) { - string shape_string = "f32[123,456]"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { - string shape_string = "(f32[1572864],s8[5120,1024])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), - ShapeUtil::MakeShape(S8, {5120, 1024})}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeTupleShape({ - ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), 
- ShapeUtil::MakeOpaqueShape(), - ShapeUtil::MakeShape(F32, {3}), - }); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithLayout) { - string shape_string = "f32[123,456]{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { - string shape_string = "f32[123,456]dense{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { - string shape_string = "f32[123,456]sparse{10}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseOpaqueType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString("opaque[]")); - Shape expected = ShapeUtil::MakeOpaqueShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseTokenType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); - Shape expected = 
ShapeUtil::MakeTokenShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseInvalidShapeString) { - string shape_strings[] = { - "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", - "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", - }; - for (const string& shape_string : shape_strings) { - StatusOr result = ShapeUtil::ParseShapeString(shape_string); - ASSERT_FALSE(result.ok()) << "shape: " << shape_string; - } -} - TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); @@ -272,6 +176,28 @@ TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); } +TEST(ShapeUtilTest, EqualDynamicShapes) { + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), + ShapeUtil::MakeShape(F32, {4, 3}, {true, false}))); + EXPECT_FALSE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), + ShapeUtil::MakeShape(F32, {4, 3}, {false, false}))); +} + +TEST(ShapeUtilTest, CompatibleDynamicShapes) { + Shape shape_a = ShapeUtil::MakeShape(F32, {4, 3}, {true, false}); + *shape_a.mutable_layout() = Layout({1, 0}); + Shape shape_b = ShapeUtil::MakeShape(F32, {4, 3}, {true, false}); + *shape_b.mutable_layout() = Layout({0, 1}); + Shape shape_c = ShapeUtil::MakeShape(F32, {4, 3}, {false, true}); + *shape_c.mutable_layout() = Layout({0, 1}); + + EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_a)); + EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_b)); + EXPECT_FALSE(ShapeUtil::Compatible(shape_a, shape_c)); +} + TEST(ShapeUtilTest, CompatibleTuples) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); @@ -612,10 +538,6 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { 
ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } -TEST(ShapeUtilTest, ShapeIs) { - EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {})); -} - TEST(ShapeUtilTest, ForEachIndex) { struct ShapeDimensionAndNumberInvocations { std::vector dimensions; @@ -788,6 +710,26 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) { } while (std::next_permutation(layout.begin(), layout.end())); } +TEST(ShapeUtilTest, PermuteDynamicDimensions) { + Shape shape = + ShapeUtil::MakeShape(F32, {10, 100, 1000}, + /*dynamic_dimensions*/ {false, true, true}); + SCOPED_TRACE(absl::StrCat("shape=", shape.ToString())); + + std::vector permutation(3); + std::iota(permutation.begin(), permutation.end(), 0); + do { + SCOPED_TRACE(absl::StrCat("permutation=", absl::StrJoin(permutation, ","))); + + auto permuted = ShapeUtil::PermuteDimensions(permutation, shape); + for (int i = 0; i < shape.rank(); i++) { + EXPECT_EQ(permuted.dimensions(permutation[i]), shape.dimensions(i)); + EXPECT_EQ(permuted.is_dynamic_dimension(permutation[i]), + shape.is_dynamic_dimension(i)); + } + } while (std::next_permutation(permutation.begin(), permutation.end())); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), @@ -819,8 +761,15 @@ TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { auto aligned_shape = ShapeUtil::AlignLayouts( input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); EXPECT_TRUE(aligned_shape); - EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), - ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +TEST(AlignmentTest, AlignLayoutsWithAllTrivialDimensions) { + Shape input = + ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 1, 1, 1}, {0, 1, 3, 2}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {1, 1, 1, 
1, 1})); + EXPECT_TRUE(aligned_shape); EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); } diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index a40bb7875e7ea53a8959a9a67ec09ec260ba9c37..82091bdee65c709bb6020f40acc15f13d8599c1d 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -79,7 +79,7 @@ void SparseIndexArray::Resize(int64 num_indices) { } bool SparseIndexArray::Validate(const Shape& shape) const { - if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) { + if (rank_ == 0 || rank_ != shape.rank()) { return false; } int64 num_indices = index_count(); diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index a96d483462efd77ae4761541e8c79b2c84fa49f3..0c25355467da3fd346d80db790d78252869975ef 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -135,7 +135,7 @@ void SparseIndexArray::SortWithValues(absl::Span values) { auto sort_order_less = [this](int64 lhs, int64 rhs) { return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; }; - std::sort(sort_order.begin(), sort_order.end(), sort_order_less); + absl::c_sort(sort_order, sort_order_less); // Reorder the array elements according to sort_order. Work through the array // and follow cycles so we can do the reorder in-place. diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc index b88fe367d7416a26c1147fd5e10fb20772814fe5..aa7238f07d432aabb44d2cbed66786217e6a846c 100644 --- a/tensorflow/compiler/xla/status_macros.cc +++ b/tensorflow/compiler/xla/status_macros.cc @@ -25,6 +25,13 @@ limitations under the License. namespace xla { namespace status_macros { +ABSL_CONST_INIT const char kPossibleAutoJitAlternative[] = + "This error might be occurring with the use of xla.compile. 
If it is not " + "necessary that every Op be compiled with XLA, an alternative is to use " + "auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment " + "variable TF_XLA_FLAGS=\"tf_xla_auto_jit=2\" which will attempt to use xla " + "to compile as much of the graph as the compiler is able to."; + static Status MakeStatus(tensorflow::error::Code code, const string& message) { return Status(code, message); } diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h index e51dd64e2a3dc7c359918cb08c6c94b2b4d9e91b..315136acc71670fa3ad48da4dc064e384ddadaa9 100644 --- a/tensorflow/compiler/xla/status_macros.h +++ b/tensorflow/compiler/xla/status_macros.h @@ -30,6 +30,10 @@ limitations under the License. namespace xla { namespace status_macros { +// This is a useful error message when encountering XLA Compiler errors that +// could be handled with the non-strict AutoJit mode. +extern const char kPossibleAutoJitAlternative[]; + // Stream object used to collect error messages in MAKE_ERROR macros // or append error messages with APPEND_ERROR. It accepts any // arguments with operator<< to build an error string, and then has an diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 5a7a4faa7e89b27fb537f20d94c21cb4a76e000d..562854756628df64fbf92d40af859f8b218b0cc2 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1,6 +1,13 @@ # Description: # Base testing infrastructure for XLA. 
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package( @@ -23,17 +30,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -75,6 +71,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", @@ -280,9 +277,6 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -319,6 +313,31 @@ xla_test( ], ) +xla_test( + name = "conv_depthwise_backprop_filter_test", + timeout = "long", + srcs = ["conv_depthwise_backprop_filter_test.cc"], + # these backends do not natively handle batch group counts. 
+ blacklisted_backends = [ + "gpu", + "cpu", + ], + shard_count = 6, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + xla_test( name = "grouped_convolution_test", timeout = "long", @@ -348,9 +367,6 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -371,9 +387,6 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -391,9 +404,6 @@ xla_test( xla_test( name = "while_test", srcs = ["while_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -417,6 +427,10 @@ xla_test( xla_test( name = "xla_hlo_profile_test", srcs = ["xla_hlo_profile_test.cc"], + blacklisted_backends = [ + # Hlo profiles are not supported on the interpreter backend. 
+ "interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", @@ -440,9 +454,6 @@ xla_test( xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -457,7 +468,6 @@ xla_test( xla_test( name = "map_test", srcs = ["map_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -510,9 +520,6 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:local_client", @@ -528,9 +535,6 @@ xla_test( xla_test( name = "select_test", srcs = ["select_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -548,7 +552,6 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -566,7 +569,6 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -627,9 +629,6 @@ xla_test( xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -652,7 +651,6 @@ xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", 
"//tensorflow/compiler/xla:array3d", @@ -677,22 +675,19 @@ xla_test( xla_test( name = "exhaustive_f32_elementwise_op_test", - size = "enormous", srcs = ["exhaustive_f32_elementwise_op_test.cc"], - backends = [ - "cpu", - "gpu", - ], + real_hardware_only = True, # Very slow on the interpreter. shard_count = 48, tags = [ - "broken", - "manual", - "notap", + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. + "no_oss", ], deps = [ ":client_library_test_base", ":literal_test_util", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/base", @@ -702,7 +697,6 @@ xla_test( xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -729,7 +723,6 @@ xla_test( srcs = ["dot_operation_test.cc"], shard_count = 20, tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -739,7 +732,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -796,7 +791,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", 
"//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -810,9 +807,6 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -832,9 +826,6 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -845,7 +836,9 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -954,6 +947,11 @@ xla_test( xla_test( name = "batch_normalization_test", srcs = ["batch_normalization_test.cc"], + blacklisted_backends = [ + # BatchNorm HLOs are not handled by the interpreter backend, and the + # BatchNorm expander is not run on the interpreter. 
+ "interpreter", + ], shard_count = 40, deps = [ ":test_utils", @@ -1045,9 +1043,6 @@ xla_test( name = "slice_test", srcs = ["slice_test.cc"], shard_count = 40, - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -1068,9 +1063,6 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1088,9 +1080,6 @@ xla_test( name = "dynamic_ops_test", timeout = "moderate", srcs = ["dynamic_ops_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -1116,9 +1105,6 @@ xla_test( xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1142,9 +1128,6 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1163,9 +1146,8 @@ xla_test( xla_test( name = "reduce_test", srcs = ["reduce_test.cc"], - shard_count = 40, + shard_count = 31, tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -1232,7 +1214,6 @@ xla_test( srcs = [], shard_count = 20, tags = [ - "enable_for_xla_interpreter", "optonly", ], xla_test_library_deps = [":reduce_window_test_library"], @@ -1244,7 +1225,6 @@ xla_test( timeout = "long", srcs = ["select_and_scatter_test.cc"], tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -1270,9 +1250,6 @@ xla_test( xla_test( name = "copy_test", srcs = ["copy_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", 
@@ -1293,9 +1270,6 @@ xla_test( xla_test( name = "reduce_hlo_test", srcs = ["reduce_hlo_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1309,9 +1283,6 @@ xla_test( xla_test( name = "token_hlo_test", srcs = ["token_hlo_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", @@ -1326,9 +1297,6 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -1348,6 +1316,7 @@ xla_test( xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], + backends = ["cpu"], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -1370,9 +1339,6 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1390,9 +1356,6 @@ xla_test( xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1412,9 +1375,6 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1434,11 +1394,8 @@ xla_test( ) xla_test( - name = "fmax_test", - srcs = ["fmax_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], + name = "fmax_fmin_test", + srcs = ["fmax_fmin_test.cc"], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1453,9 +1410,6 @@ xla_test( xla_test( name = 
"log_test", srcs = ["log_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1470,9 +1424,6 @@ xla_test( xla_test( name = "matrix_ops_simple_test", srcs = ["matrix_ops_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1519,9 +1470,6 @@ xla_test( name = "reshape_test", srcs = ["reshape_test.cc"], shard_count = 30, - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1547,9 +1495,6 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1568,9 +1513,6 @@ xla_test( xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:shape_util", @@ -1594,9 +1536,6 @@ xla_test( xla_test( name = "concat_test", srcs = ["concat_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1617,9 +1556,6 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1637,8 +1573,12 @@ xla_test( ) xla_test( - name = "cross_replica_sum_test", - srcs = ["cross_replica_sum_test.cc"], + name = "all_reduce_test", + srcs = ["all_reduce_test.cc"], + blacklisted_backends = [ + # All reduce is not supported on the interpreter backend. 
+ "interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1663,9 +1603,6 @@ xla_test( xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1705,9 +1642,6 @@ xla_test( xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1769,6 +1703,10 @@ xla_test( xla_test( name = "execution_profile_test", srcs = ["execution_profile_test.cc"], + blacklisted_backends = [ + # Execution profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", @@ -1783,6 +1721,10 @@ xla_test( name = "execution_profile_test_with_xla_hlo_profile", srcs = ["execution_profile_test.cc"], args = ["--xla_hlo_profile"], + blacklisted_backends = [ + # Hlo profiles are not supported on the interpreter backend. 
+ "interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", @@ -1796,9 +1738,6 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -1821,9 +1760,6 @@ xla_test( xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1885,9 +1821,6 @@ xla_test( xla_test( name = "fusion_test", srcs = ["fusion_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -2005,6 +1938,10 @@ xla_test( xla_test( name = "outfeed_in_nested_computation_test", srcs = ["outfeed_in_nested_computation_test.cc"], + blacklisted_backends = [ + # Outfeed ops are not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla/tests:local_client_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2181,7 +2118,6 @@ xla_test( srcs = ["iota_test.cc"], shard_count = 30, tags = [ - "enable_for_xla_interpreter", # Require optimized builds, iota_test_cpu is very slow in fastbuild. "optonly", ], @@ -2209,3 +2145,41 @@ tf_cc_test( "@com_google_absl//absl/synchronization", ], ) + +xla_test( + name = "ptxas_bug_120501638", + srcs = ["ptxas_bug_120501638.cc"], + tags = [ + # Disabled in OSS until nvidia publicly releases a fixed ptxas. 
+ "no_oss", + ], + deps = [ + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:test", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + tags = [ + "enable_for_xla_interpreter", + "noasan", # sometimes times out, http://b/78650012 + ], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/all_reduce_test.cc similarity index 94% rename from tensorflow/compiler/xla/tests/cross_replica_sum_test.cc rename to tensorflow/compiler/xla/tests/all_reduce_test.cc index 410732c07b7b6d3ece33ab11f4778241dc53ca50..7e695f829e39831e2c8558cb07d0689e560bbafa 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/all_reduce_test.cc @@ -41,7 +41,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { ENTRY test_computation { p = f32[3] parameter(0) - ROOT crs = f32[3] cross-replica-sum(p), to_apply=add + ROOT crs = f32[3] all-reduce(p), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -62,7 +62,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] parameter(1) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), 
to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -88,7 +88,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] constant({10, 20}) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 915b456b52215f8d6a9eb6c5b933f3502f1d3d2c..acdd3c9da92efe8fae1336eaa861c01d5bb9b158 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -1443,6 +1442,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto lhs = + ConstantR1(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f}); + auto rhs = + ConstantR1(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f}); + Pow(lhs, rhs); + + ComputeAndCompareR1(&builder, + { + {0, 1.41421356}, + {-2.27443288e-01, 0.69999846}, + {-4.19847531e-01, -1.29215783}, + {0, 0}, + {0, 0}, + {1, 0}, + }, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {}); @@ -2047,6 +2067,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) { + 
SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto minimum = ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN}); + auto argument = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = ConstantR1(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f}); + Clamp(minimum, argument, maximum); + + ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, NAN, NAN}, {}, + error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XlaBuilder builder(TestName()); auto minimum = ConstantR0(&builder, 0.0f); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index e9728e636f0ee032416b2da17a3ea83c5bb18083..63e48117056dec4af603cbc85e478fcb15ad0cec 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -76,7 +76,9 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) { error_spec_); } -XLA_TEST_F(Bfloat16Test, BatchNormTraining) { +// Disabled on interpreter since BatchNormExpander is not run by default on the +// interpreter backend. +XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); @@ -110,7 +112,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } -XLA_TEST_F(Bfloat16Test, BatchNormGrad) { +// Disabled on interpreter since BatchNormExpander is not run by default on the +// interpreter backend. 
+XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 05d4d04034bf50c8bb840e59b28a590fce048c19..c14d279ac560db33066ae4fc68b6290f7499bb39 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -34,6 +34,7 @@ def xla_test( xla_test_library_deps = [], backends = [], blacklisted_backends = [], + real_hardware_only = False, args = [], tags = [], copts = [], @@ -108,6 +109,10 @@ def xla_test( use for that target. **kwargs: Additional keyword arguments to pass to native.cc_test. """ + + # All of the backends in all_backends are real hardware. + _ignore = [real_hardware_only] + test_names = [] if not backends: backends = all_backends diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 12c029983336cc9aed0fde4ce6881c9a00a9869e..0e99ede5d01fcfa88c54c9cbc5a6a85bf8f15ddf 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include #include #include "absl/memory/memory.h" @@ -40,8 +41,9 @@ constexpr char kInterpreter[] = "interpreter"; // Wrapper function that creates a nicer error message (than a bare // ValueOrDie()) if the platform we intend to test is not available. 
-Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { - StatusOr result = +LocalClient* GetOrCreateLocalClientOrDie( + const LocalClientOptions& client_options) { + StatusOr result = ClientLibrary::GetOrCreateLocalClient(client_options); TF_CHECK_OK(result.status()) << " could not create local client for testing"; return result.ValueOrDie(); @@ -74,6 +76,9 @@ ClientLibraryTestBase::ClientLibraryTestBase( // default. execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) @@ -88,6 +93,9 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } string ClientLibraryTestBase::TestName() const { @@ -184,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( verify_output(actual, ""); // Try with all output layouts. - std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); + std::vector minor_to_major(expected.shape().rank()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto layout = ShapeUtil::MakeShapeWithLayout( @@ -217,7 +225,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. 
- if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { layout_strings.push_back( ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); @@ -227,7 +235,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); + std::vector minor_to_major(literal.shape().rank()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = @@ -273,9 +281,10 @@ StatusOr ClientLibraryTestBase::ComputeAndTransfer( if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } @@ -296,9 +305,10 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } @@ -356,9 +366,10 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + 
client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 65a23dd883594b9bf9c37494a37e9be39b197788..d700437ed355c144639f76d683055e211975fde9 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -385,8 +385,8 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); - Client* client_; - Client* ref_client_; // To compute reference result. + LocalClient* client_; + LocalClient* ref_client_; // To compute reference result. ExecutionOptions execution_options_; private: @@ -431,7 +431,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -455,7 +456,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -480,7 +482,8 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal 
expected_literal = LiteralUtil::CreateR2FromArray2D(expected); @@ -506,7 +509,8 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); @@ -532,7 +536,8 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 363dee74b2755a6bdc3c5a5164a85378581c21d2..247328b730f3af936d933f824da491b593b27c90 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -96,7 +96,7 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_TRUE(result.shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( @@ -109,7 +109,10 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } -XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { +// Disabled for interpreter since ExecuteAsyncOnStream is not implemented on +// interpreter backend. 
+XLA_TEST_F(ClientTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(ExecuteParallel))) { XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 3b0414a6045a7c5f4f75948d8ccf2775c575626e..ef800b8ef624bf1020ff1e6857c13b0387482cd3 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -151,19 +151,35 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { } } -TEST_F(ComputeConstantTest, IndirectParamMissing) { +TEST_F(ComputeConstantTest, GetDimensionSize) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = - Add(ConstantR0(&b, 1.0f), - Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); - EXPECT_FALSE(IsConstant(computation, &b)); + auto add = + Add(ConstantR1(&b, {1.0f}), ConstantR1(&b, {1.0f})); + auto get_dimension_size = GetDimensionSize(add, 0); + EXPECT_TRUE(IsConstant(get_dimension_size, &b)); + + TF_ASSERT_OK_AND_ASSIGN(auto value, ComputeConstantScalar( + client, get_dimension_size, &b)); + EXPECT_EQ(value, 1); + } +} - auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE( - absl::StrContains(value.status().ToString(), "depends on a parameter")) - << value.status(); +TEST_F(ComputeConstantTest, MultipleGetDimensionSize) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto add = + Add(ConstantR2(&b, {{1.0f}}), ConstantR2(&b, {{1.0f}})); + auto get_dimension_size = GetDimensionSize(add, 0); + auto get_dimension_size_2 = GetDimensionSize(add, 0); + auto add_2 = Add(get_dimension_size, get_dimension_size_2); + EXPECT_TRUE(IsConstant(add_2, &b)); + + TF_ASSERT_OK_AND_ASSIGN(auto value, + 
ComputeConstantScalar(client, add_2, &b)); + EXPECT_EQ(value, 2); } } diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 72ff1e74a47c8584cb5336c86a1c978c4637a902..6530007871ced1d0bbffe2b44ccc8cf9bddd79e1 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,11 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -178,5 +181,54 @@ TEST_F(ConstantsTest, Token) { TF_ASSERT_OK(Execute(&builder, {}).status()); } +TEST_F(ConstantsTest, FullLike) { + XlaBuilder b(TestName()); + auto val1 = Iota(&b, F32, 3); + auto val2 = FullLike(val1, 10); + val1 + val2; + ComputeAndCompareR1(&b, {10, 11, 12}, {}, error_spec_); +} + +TEST_F(ConstantsTest, IllegalFullLikeOnTuple) { + XlaBuilder b(TestName()); + auto tuple = Tuple(&b, {Iota(&b, F32, 3), Iota(&b, F32, 1)}); + FullLike(tuple, 10); // Illegal; can't do FullLike on a tuple. 
+ EXPECT_FALSE(b.Build().ok()); +} + +TEST_F(ConstantsTest, FullLikeScalar) { + XlaBuilder b(TestName()); + auto scalar1 = ConstantR0WithType(&b, F32, 1); + auto scalar2 = FullLike(scalar1, 2); + scalar1 - scalar2; + ComputeAndCompareR0(&b, -1, {}, error_spec_); +} + +class ConstantsHloTest : public HloTestBase {}; + +// TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior. +XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) { + const char* testcase = R"( + HloModule module, is_scheduled=true + + func { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] add(lhs, rhs) + } + + ENTRY test { + constant.0 = s32[1]{0} constant({0}) + parameter.0 = s32[] parameter(0) + constant-as-scalar = s32[] bitcast(constant.0) + ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func + } + )"; + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); + auto param = LiteralUtil::CreateR0(1); + auto result = ExecuteNoHloPasses(std::move(module), {¶m}); + EXPECT_TRUE(LiteralTestUtil::Equal(param, result)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dfbf0478e62713635446d11557367cfac6ab0dce --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -0,0 +1,178 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct BatchGroupedConvolution2DSpec { + int64 output_batch, window, window_dilation; + std::vector activation_dims; + std::vector kernel_dims; + std::vector output_dims; + std::vector activation_and_kernel_layout; + std::vector output_layout; +}; + +class BatchGroupedConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + std::vector> config_options = { + {8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128}, + {16, 20, 20, 256}, {256, 7, 5, 4}, {256, 6, 6, 4}, + {256, 8, 8, 512}, {64, 7, 7, 960}, {64, 14, 14, 576}}; + + for (auto option : config_options) { + int64 feature = option[3]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[0]; + + BatchGroupedConvolution2DSpec config; + config.window_dilation = 1; + config.output_batch = feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, feature}; + + config.kernel_dims = {batch, kernel_size, 
kernel_size, feature}; + + int64 output_space_size = 3 + activation_size - kernel_size; + config.output_dims = {output_space_size, output_space_size, feature, 1}; + + config.activation_and_kernel_layout = {0, 3, 1, 2}; + config.output_layout = {2, 3, 0, 1}; + config_set.push_back(config); + + BatchGroupedConvolution2DSpec different_layout_config = config; + different_layout_config.activation_and_kernel_layout = {3, 0, 1, 2}; + config_set.push_back(different_layout_config); + + // Add configurations for window dilation cases. + if (activation_size % 2 == 0 && activation_size == kernel_size) { + BatchGroupedConvolution2DSpec config; + config.window_dilation = 2; + config.output_batch = feature; + config.window = kernel_size / 2; + config.activation_dims = {batch, activation_size, activation_size, + feature}; + config.kernel_dims = {batch, kernel_size / 2, kernel_size / 2, feature}; + config.activation_and_kernel_layout = {0, 3, 1, 2}; + config.output_layout = {2, 3, 0, 1}; + + int64 output_space_size = 5; + config.output_dims = {output_space_size, output_space_size, feature, 1}; + + config_set.push_back(config); + + BatchGroupedConvolution2DSpec different_layout_config = config; + different_layout_config.activation_and_kernel_layout = {3, 0, 1, 2}; + config_set.push_back(different_layout_config); + } + } + + return config_set; +} + +string BatchGroupedConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), + "_activation_layout_", + absl::StrJoin(spec.activation_and_kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), data_type, "_output_layout_", + absl::StrJoin(spec.output_layout, "_")); + + // Test names are not allowed to 
contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextBatchGroupedConvolution2D( + const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv, is_scheduled=true + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_%dx1_%d rhs_dilate=%dx%d}, dim_labels=f01b_i01o->01fb, + batch_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.activation_and_kernel_layout, ","), spec.window, + spec.window, spec.window_dilation, spec.window_dilation, + spec.window_dilation, spec.window_dilation, spec.output_batch); +} + +XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) { + const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = + BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompareNoHloPasses( + hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + BatchGroupedConvolution2DTestWithRandomIndices, + BatchGroupedConvolution2DTest, + 
::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + BatchGroupedConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 4a58a1ed66c438d1dd9561f4eb029b38d8c6cbdd..9db9f2563b636c4f929585eb13a9c7f809833eda 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -98,7 +98,7 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { precision.add_operand_precision(PrecisionConfig::HIGHEST); precision.add_operand_precision(PrecisionConfig::DEFAULT); Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, - &precision); + /*batch_group_count=*/1, &precision); ComputeAndCompare(&builder, {}, error_spec_); } @@ -467,8 +467,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { // servers. The error message is missing the operator ++. template void iota_int_init_value(std::vector& values, int init_value) { - std::for_each(values.begin(), values.end(), - [&](T& value) { value = static_cast(init_value++); }); + absl::c_for_each(values, + [&](T& value) { value = static_cast(init_value++); }); } template diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 3622f2c1e84639baed13059b21b20609d1347da6..df005a67097bb8aaf070c57d1c51acd1909fee12 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -133,7 +133,9 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); - literal_layout->mutable_minor_to_major()->SwapElements(0, 1); + // Swap the first and second elements. 
+ *literal_layout->mutable_minor_to_major() = { + literal_layout->minor_to_major(1), literal_layout->minor_to_major(0)}; HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 738b6442354b01364278e3e3c713aa2cdb5cf47d..4687ed61a7de91bc1bce0efeadf1965ad7d52d55 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -54,11 +54,20 @@ void Add1ToValues(float* out, float** in) { out[2] = array[2] + 1; out[3] = array[3] + 1; } + +void F32TupleSwap(float** out, float** in) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float)); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float)); + *out[0] = *in[1]; + *out[1] = *in[0]; +} + } // namespace REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); +REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); namespace xla { namespace { @@ -69,7 +78,7 @@ class CustomCallTest : public HloTestBase { Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2}); }; -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { +XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) { auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -84,7 +93,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { +XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) { auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -105,7 +114,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) 
{ +XLA_TEST_F(CustomCallTest, UsedInOtherComputations) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -129,7 +138,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { +XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -151,7 +160,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { +XLA_TEST_F(CustomCallTest, LayoutConstrained) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. @@ -163,8 +172,10 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { const Shape& r2f32_dim0_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); - b.AddInstruction(HloInstruction::CreateCustomCall( + auto custom_call = b.AddInstruction(HloInstruction::CreateCustomCall( r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + b.AddInstruction( + custom_call->CloneWithNewOperands(r2f32_dim0_major, {custom_call})); module->AddEntryComputation(b.Build()); ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); @@ -173,7 +184,27 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); Literal result = ExecuteAndTransfer(std::move(module), {&argument}); - LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); + LiteralTestUtil::ExpectR2Equal({{3.f, 4.f}, {5.f, 6.f}}, result); +} + +XLA_TEST_F(CustomCallTest, TupleOutput) { + const char* kModuleStr = R"( + 
HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(7.f); + Literal arg1 = LiteralUtil::CreateR0(42.f); + + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1}); + EXPECT_EQ(result, expected); } class CustomCallClientAPITest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index c5d8b663f4abe77e05ec213d2e4e075c260a8655..6ee2178a227a12b7baa933f036a44db8ec630a4c 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -19,12 +19,14 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -918,8 +920,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto one = ConstantR0(&builder, 1); + auto zero = ConstantR0(&builder, 0); + auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -945,8 +948,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; 
dot_dnums.add_lhs_contracting_dimensions(1); @@ -974,8 +978,9 @@ XLA_TEST_F(DotOperationTest, XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1001,8 +1006,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1033,8 +1039,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1065,8 +1072,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { XlaBuilder 
builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1089,8 +1097,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1113,8 +1122,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1147,5 +1157,192 @@ XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { ComputeAndCompareR2(&builder, expected, {}, error_spec_); } + +using EinsumParamType = + std::tuple, std::vector, string>; 
+class EinsumTest : public DotOperationTest, + public ::testing::WithParamInterface {}; +XLA_TEST_P(EinsumTest, SimpleEinsumTest) { + XlaBuilder builder(TestName()); + auto x = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam()))) + .ValueOrDie(), + &builder); + auto y = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam()))) + .ValueOrDie(), + &builder); + Einsum(x, y, std::get<2>(GetParam())); + ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3}); +} + +std::vector GetEinsumTestCases() { + using v = std::vector; + using p = EinsumParamType; + std::vector

test_cases = { + p{v{5, 6}, v{6, 7}, "mk,kn->mn"}, + p{v{5, 6}, v{6, 7}, "mk,kn->nm"}, + p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"}, + p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"}, + p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"}, + p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, + p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, + p{v{6}, v{6, 7}, "b,bc->c"}, + p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"}, + p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"}, + p{v{77}, v{77}, "a,a->a"}, + p{v{77}, v{77, 55}, "a,ab->ba"}, + p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"}, + p{v{55}, v{}, "a,->a"}, + p{v{11, 111}, v{11}, "ab,a->ab"}, + p{v{16, 34}, v{16, 34}, "ab,ab->ab"}, + p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"}, + p{v{5, 19}, v{}, "ab,->ab"}, + }; + return test_cases; +} + +INSTANTIATE_TEST_CASE_P(Einsum, EinsumTest, + ::testing::ValuesIn(GetEinsumTestCases())); + +class DotOperationTextTest : public HloTestBase {}; + +XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) { + absl::string_view hlo_string = + R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) { + absl::string_view hlo_string = + R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) { + 
absl::string_view hlo_string = + R"( +HloModule DotWithNoDnums + +ENTRY %test { + %lhs = f32[2,3]{1,0} parameter(0) + %rhs = f32[4,5]{1,0} parameter(1) + ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, Einsum) { + absl::string_view hlo_string = + R"( +HloModule Einsum + +ENTRY %test { + %lhs = f32[8,64,96]{2,1,0} parameter(0) + %rhs = f32[96,32,4]{2,1,0} parameter(1) + ROOT %dot = f32[8,64,32,4]{3,2,1,0} dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) { + // Tests for a caching bug in the XLA CPU backend. + absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(2) + rhs_1 = f32[1,40] parameter(1) + + dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + ROOT result = f32[20,1] divide(dot_0, dot_1) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) { + // Tests for a caching bug in the XLA CPU backend. 
+ absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs_0 = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(1) + lhs_1 = f32[1,40] parameter(2) + rhs_1 = f32[20,40] parameter(3) + + dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + dot_0_reshaped = f32[20] reshape(dot_0) + dot_1_reshaped = f32[20] reshape(dot_1) + + ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(GpuIntegerDotCodegen)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s32[1,2,2] parameter(0) + arg1 = s32[1,2,1] parameter(1) + ROOT dot = s32[1,2,1] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(GpuTransposeOutput)) { + absl::string_view hlo_string = + R"( +HloModule TransposeOutput + +ENTRY TransposeOutput { + p0 = f32[32,32] parameter(0) + p1 = f32[32,64] parameter(1) + dot = f32[32,64] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0} + ROOT tr = f32[64,32] transpose(dot), dimensions={1,0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 7501c6d957e7afe99b8c530e5f0d575f818367da..82e2db36143b2552472fedae701f32389a9be108 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -135,11 +135,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder 
builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::unique_ptr start_data = CreateR0Parameter( + slice_starts[0], 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); - DynamicSlice(input, starts, slice_sizes); + DynamicSlice(input, absl::Span({starts}), slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -160,14 +160,23 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(2); + std::vector> start_data(2); + for (int i = 0; i < 2; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } + // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -186,14 +195,22 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
- XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(3); + std::vector> start_data(3); + for (int i = 0; i < 3; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } }; @@ -372,16 +389,12 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { .ValueOrDie()); XlaBuilder builder(TestName()); - // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_value); auto update = ConstantLiteral(&builder, update_value); - DynamicUpdateSlice(input, update, starts); + DynamicUpdateSlice(input, update, absl::Span({})); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_value, {}); } template @@ -405,12 +418,12 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::unique_ptr start_data = CreateR0Parameter( + slice_starts[0], 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); - DynamicUpdateSlice(input, update, starts); + DynamicUpdateSlice(input, update, absl::Span({starts})); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -435,15 +448,23 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(2); + std::vector> start_data(2); + for (int i = 0; i < 2; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -466,15 +487,24 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
- XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(3); + std::vector> start_data(3); + for (int i = 0; i < 3; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } + // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -518,8 +548,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); - auto starts = ConstantR1(&builder, {index, 0, 0}); - DynamicUpdateSlice(input, update, starts); + auto constant_index = ConstantR0(&builder, index); + auto zero = ConstantR0(&builder, 0); + DynamicUpdateSlice(input, update, {constant_index, zero, zero}); // Run computation and compare against expected values. 
ComputeAndCompareR3(&builder, expected_values, @@ -720,46 +751,55 @@ void BM_DynamicSlice(int num_iters) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); auto input = ConstantLiteral(&builder, input_literal); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); + // Create dynamic slice start indices as a parameter: shape [4] - auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); - auto start_indices = - Parameter(&builder, 0, start_indices_shape, "start_indices"); + auto start_indices_shape = ShapeUtil::MakeShape(S32, {}); + std::vector start_indices(4); + std::vector shaped_buffers; + std::vector host_shapes(4); + for (int i = 0; i < 4; ++i) { + start_indices[i] = + Parameter(&builder, i, start_indices_shape, "start_indices"); + auto start_index_literal = LiteralUtil::CreateR0(i + 1); + // Initialize and transfer parameter buffer. + shaped_buffers.emplace_back( + client->backend() + .transfer_manager() + ->AllocateScopedShapedBuffer(start_indices_shape, &allocator, + /*device_ordinal=*/0) + .ConsumeValueOrDie()); + host_shapes[i] = &shaped_buffers[i].on_host_shape(); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + stream.get(), start_index_literal, shaped_buffers[i])); + } + // Add DynamicSlice op to the computatation. DynamicSlice(input, start_indices, {1, 1, 1, 1}); auto computation = builder.Build().ConsumeValueOrDie(); - // Initialize and transfer parameter buffer. 
- auto buffer = client->backend() - .transfer_manager() - ->AllocateScopedShapedBuffer( - start_indices_shape, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); - - auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); - auto stream = - client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - stream.get(), start_indices_literal, buffer)); - std::unique_ptr executable = - client - ->Compile(computation, {&buffer.on_host_shape()}, - ExecutableBuildOptions()) + client->Compile(computation, host_shapes, ExecutableBuildOptions()) .ConsumeValueOrDie(); // Run some warm-up executions. ExecutableRunOptions options; options.set_allocator(&allocator); const int kWarmups = 2; + std::vector shaped_buffer_ptrs; + absl::c_transform(shaped_buffers, std::back_inserter(shaped_buffer_ptrs), + [](const ScopedShapedBuffer& buffer) { return &buffer; }); + for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({&buffer}, options); + auto result = executable->Run(shaped_buffer_ptrs, options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({&buffer}, options); + auto result = executable->Run(shaped_buffer_ptrs, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index c84973e17b234c24c84f02a369ce0185f5772cca..b961e6102692cb3b90976d621c62cb4cf18a9b6b 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include "absl/base/casts.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -21,66 +23,166 @@ limitations under the License. namespace xla { namespace { + class ExhaustiveF32ElementwiseOpTest : public ClientLibraryTestBase, public ::testing::WithParamInterface> { protected: - ErrorSpec error_spec_{0.0001, 0.0001, /*relaxed_nans=*/true}; + ErrorSpec error_spec_{0.0001, 0.0001}; + + bool IsClose(float expected, float actual) { + float abs_err = std::abs(expected - actual); + float rel_err = abs_err / std::abs(expected); + return abs_err < error_spec_.abs || rel_err < error_spec_.rel || + (std::isnan(expected) && std::isnan(actual)) || + (std::isinf(expected) && std::isinf(actual) && + (expected > 0) == (actual > 0)); + } template void ExhaustivelyTestF32Op(EnqueueOpTy enqueue_op, float (*evaluate_op)(float), std::pair known_incorrect_range) { + SetFastMathDisabled(true); + int64 begin, end; std::tie(begin, end) = GetParam(); int64 input_size = end - begin; + + if (begin >= known_incorrect_range.first && + end <= known_incorrect_range.second) { + LOG(INFO) << absl::StreamFormat( + "Skipping this shard, as the range under test, [%d, %d), falls " + "entirely within the known-incorrect range [%d, %d).", + begin, end, known_incorrect_range.first, + known_incorrect_range.second); + return; + } + LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; XlaBuilder builder(TestName()); - Literal input_literal = - LiteralUtil::CreateFromDimensions(F32, {input_size}); - for (int64 i = begin; i < end; i++) { + auto ith_input_elem = [&](int64 i) -> float { + i += begin; + // If the operation is known to be buggy on a specific input clamp that + // input to 0 under the assumption that the op 
is at least correct on 0. if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { - // If the operation is known to be buggy on a specific input clamp that - // input to 0 under the assumption that the op is at least correct on 0. - input_literal.Set({i - begin}, 0.0f); - } else { - input_literal.Set({i - begin}, absl::bit_cast(i)); + return 0; } - } - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(input_literal)); + return absl::bit_cast(i); + }; + Literal input_literal = + LiteralUtil::CreateFromDimensions(F32, {input_size}); + absl::Span input_arr = input_literal.data(); + for (int64 i = 0; i < input_size; i++) { + input_arr[i] = ith_input_elem(i); + } auto input = Parameter(&builder, 0, input_literal.shape(), "input"); enqueue_op(&builder, input); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); + + // Build and run the computation using the LocalClient API, rather than the + // plain Client API, which is used by ClientLibraryTestBase. This is + // because the plain Client API results does more memcpys to/from Literals, + // and that's slow given that we're touching a lot of data here. + // + // Copy debug options from ClientLibraryTestBase. In particular, we're + // interested in disabling constant folding. 
+ ExecutableBuildOptions build_opts; + *build_opts.mutable_debug_options() = *mutable_debug_options(); + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + client_->Compile(comp, {&input_literal.shape()}, build_opts)); + + TF_ASSERT_OK_AND_ASSIGN( + ScopedShapedBuffer input_data, + client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0)); + + ExecutableRunOptions run_opts; + run_opts.set_allocator(client_->backend().memory_allocator()); + run_opts.set_intra_op_thread_pool( + client_->backend().eigen_intra_op_thread_pool_device()); + TF_ASSERT_OK_AND_ASSIGN(ScopedShapedBuffer result, + executable->Run({&input_data}, run_opts)); + + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + client_->ShapedBufferToLiteral(result)); + + // We essentially reimplement LiteralTestUtil::Near here because + // a) this streamlined implementation is much faster, and + // b) we can print out better error messages (namely, we can print out + // which floating-point value input failed, while LiteralTestUtil::Near + // can only print out the input index that failed). + // c) we need special handling of certain inputs. For example, we say that + // a denormal input has multiple correct outputs (namely, f(x) and f(0)) + // and just needs to be close to one of them. + absl::Span result_arr = result_literal.data(); + ASSERT_EQ(result_arr.size(), input_arr.size()); + int64 mismatches = 0; + // Hoisting this out of the loop is a nice speedup on shards that have many + // denormals. 
+ const float expected_at_zero = evaluate_op(0); + for (int64 i = 0; i < input_arr.size(); ++i) { + float input = ith_input_elem(i); + float actual = result_arr[i]; + float expected = evaluate_op(input); + if (IsClose(expected, actual)) { + continue; + } - std::vector expected_result; - expected_result.reserve(input_size); - for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal.Get({i}))); - } + constexpr int64 kMaxMismatchesPrinted = 1000; + if (std::fpclassify(input) == FP_SUBNORMAL) { + // For denormal inputs, we accept answers that are close to either + // - evaluate_op(input) OR + // - evaluate_op(0). + if (IsClose(expected_at_zero, actual)) { + continue; + } + ++mismatches; + if (mismatches < kMaxMismatchesPrinted || VLOG_IS_ON(2)) { + // Use %0.9g because that's guaranteed to print an f32 to full + // precision. + LOG(ERROR) << absl::StreamFormat( + "Mismatch on denormal value %0.9g (0x%08x). Expected either " + "%0.9g (0x%08x) (evaluated at true value) or %0.9g (0x%08x) " + "(evaluated at zero), but got %0.9g (0x%08x).", + input, absl::bit_cast(input), // + expected, absl::bit_cast(expected), // + expected_at_zero, absl::bit_cast(expected_at_zero), + actual, absl::bit_cast(actual)); + } + } else { + mismatches++; + if (mismatches < kMaxMismatchesPrinted || VLOG_IS_ON(2)) { + LOG(ERROR) << absl::StreamFormat( + "Mismatch on %0.9g (0x%08x). 
Expected %0.9g (0x%08x), but got " + "%0.9g (0x%08x).", + input, absl::bit_cast(input), // + expected, absl::bit_cast(expected), // + actual, absl::bit_cast(actual)); + } + } - ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, - error_spec_); + if (mismatches == kMaxMismatchesPrinted && !VLOG_IS_ON(2)) { + LOG(ERROR) << "Not printing any more mismatches; pass " + "--vmodule=exhaustive_f32_elementwise_op_test=2 to see " + "all of them."; + } + } + EXPECT_EQ(mismatches, 0); } }; XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { -#ifdef XLA_TEST_BACKEND_CPU - // TODO(b/73141998): The vectorized Log implementation gives results outside - // our error spec in this range (these numbers are bitwise representations of - // floats expressed as a zero extended int64). - std::pair known_incorrect_range = {1, 8388608}; -#else - std::pair known_incorrect_range = {0, 0}; +#if !defined(XLA_TEST_BACKEND_CPU) && !defined(XLA_TEST_BACKEND_GPU) + error_spec_ = ErrorSpec{0.001, 0.001}; #endif - ExhaustivelyTestF32Op( [](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log, - known_incorrect_range); + /*known_incorrect_range=*/{0, 0}); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { @@ -105,6 +207,18 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { /*known_incorrect_range=*/{0, 0}); } +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ErfF32) { + ExhaustivelyTestF32Op( + [](XlaBuilder* builder, const XlaOp& input) { Erf(input); }, std::erf, + /*known_incorrect_range=*/{0, 0}); +} + +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ErfcF32) { + ExhaustivelyTestF32Op( + [](XlaBuilder* builder, const XlaOp& input) { Erfc(input); }, std::erfc, + /*known_incorrect_range=*/{0, 0}); +} + std::vector> CreateExhaustiveParameters() { // We break up the 2^32-element space into small'ish chunks to keep peak // memory usage low. 
diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index dcb469087e0064d17ce3b04fdeaf0b6136069a55..1b0bebe2d03a9a153cd0c80329ed0c49c91333a3 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -48,7 +48,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { tensorflow::SubProcess file_check_process; file_check_process.SetProgram(file_check_path, - {file_check_path, pattern_path}); + {file_check_path, "-v", pattern_path}); file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, tensorflow::ACTION_PIPE); file_check_process.SetChannelAction(tensorflow::CHAN_STDERR, @@ -71,9 +71,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { LOG(WARNING) << "NOTE: FileCheck binary does not exist!"; } - LOG(WARNING) << "FileCheck error: " << standard_error; - LOG(WARNING) << "FileCheck input was:"; - XLA_LOG_LINES(tensorflow::WARNING, input); + LOG(WARNING) << "FileCheck error:\n" << standard_error; LOG(WARNING) << "FileCheck pattern was:"; XLA_LOG_LINES(tensorflow::WARNING, pattern); } else if (!standard_error.empty()) { diff --git a/tensorflow/compiler/xla/tests/fmax_fmin_test.cc b/tensorflow/compiler/xla/tests/fmax_fmin_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7423ac0bcdb0bc305ee384fb98bd17413404ecef --- /dev/null +++ b/tensorflow/compiler/xla/tests/fmax_fmin_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class FmaxSimpleTest : public ClientLibraryTestBase {}; + +TEST_F(FmaxSimpleTest, FmaxTenValues) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); + + std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(FmaxSimpleTest, FmaxEdgeCases) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + XlaOp param0, param1; + std::unique_ptr param0_data = CreateR1Parameter( + {INFINITY, INFINITY, INFINITY, -INFINITY, INFINITY, -INFINITY, NAN, + INFINITY, -INFINITY, NAN}, + /*parameter_number=*/0, /*name=*/"param0", + /*builder=*/&builder, /*data_handle=*/¶m0); + std::unique_ptr param1_data = CreateR1Parameter( + {INFINITY, -INFINITY, NAN, NAN, -4.0, -5.0, -6.0, 7.0, 8.0, 9.0}, + /*parameter_number=*/1, /*name=*/"param1", + /*builder=*/&builder, /*data_handle=*/¶m1); + + Max(param0, param1); + std::vector expected = {INFINITY, INFINITY, NAN, NAN, INFINITY, + -5, NAN, INFINITY, 8, NAN}; + ComputeAndCompareR1(&builder, expected, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.0001)); +} + +TEST_F(FmaxSimpleTest, FminEdgeCases) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + XlaOp param0, param1; + std::unique_ptr 
param0_data = CreateR1Parameter( + {INFINITY, INFINITY, INFINITY, -INFINITY, INFINITY, -INFINITY, NAN, + INFINITY, -INFINITY, NAN}, + /*parameter_number=*/0, /*name=*/"param0", + /*builder=*/&builder, /*data_handle=*/¶m0); + std::unique_ptr param1_data = CreateR1Parameter( + {INFINITY, -INFINITY, NAN, NAN, -4.0, -5.0, -6.0, 7.0, 8.0, 9.0}, + /*parameter_number=*/1, /*name=*/"param1", + /*builder=*/&builder, /*data_handle=*/¶m1); + + Min(param0, param1); + std::vector expected = {INFINITY, -INFINITY, NAN, NAN, -4, + -INFINITY, NAN, 7, -INFINITY, NAN}; + ComputeAndCompareR1(&builder, expected, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc deleted file mode 100644 index c5bbbe778df15d63a2586bd6291a7a33fc82aa52..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class FmaxSimpleTest : public ClientLibraryTestBase {}; - -TEST_F(FmaxSimpleTest, FmaxTenValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1( - &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = ConstantR1( - &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - Max(x, y); - - std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, - 5.0, 6.0, 7.0, 8.0, 9.0}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index d1fddf9d6b494a822610e41307fa103dc90bdef3..2178c9b3f3d39ac034c59585c6836d2bc59162c1 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -523,10 +523,10 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto dynamic_slice2 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {2}), const0, const1, {2})); + ShapeUtil::MakeShape(S32, {2}), const0, {const1}, {2})); auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2)); hlo_module->AddEntryComputation(builder.Build()) diff --git 
a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index daa89398a697af9149797d621c3bdca80a00aedd..d65b67a535d43553a3a94f76482ad4618f9b8aab 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -600,7 +600,9 @@ ENTRY main { class GatherClientLibraryTest : public ClientLibraryTestBase {}; -XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { +// Disabled on interpreter since ExectuteAsyncOnStream is not supported. +XLA_TEST_F(GatherClientLibraryTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(Basic))) { // We create this HLO, but using the XlaBuilder API. // // ENTRY main { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 989a7c705a8254f99e5cc0e97dfde5942f146964..0151981ef16aabe9e363bc4d7f9ba96d4a1f170f 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -139,7 +139,8 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( const string& name) { return absl::make_unique( name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); + allow_mixed_precision_in_hlo_verifier_, + backend().compiler()->ShapeSizeBytesFunction()); } StatusOr> @@ -147,7 +148,8 @@ HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { auto module = absl::make_unique( TestName(), config, verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); + allow_mixed_precision_in_hlo_verifier_, + backend().compiler()->ShapeSizeBytesFunction()); TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); TF_RETURN_IF_ERROR(module->Verify()); return std::move(module); @@ -181,6 +183,7 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + debug_options.set_xla_hlo_evaluator_use_fast_path(true); return debug_options; } @@ -202,6 +205,17 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } +StatusOr> HloTestBase::ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + for (auto argument : arguments) { + options.arguments.push_back(argument); + } + return test_runner_.ExecuteReplicated(std::move(module), options); +} + StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { @@ -310,7 +324,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } -::testing::AssertionResult HloTestBase::Run(string_view hlo_string) { +::testing::AssertionResult HloTestBase::Run(string_view hlo_string, + bool run_hlo_passes, + ExecutionProfile* profile, + string backend_config) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); if (!module_or_status.ok()) { @@ -318,19 +335,108 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( << "Error while parsing HLO text format: " << module_or_status.status().ToString(); } + + std::unique_ptr module = std::move(module_or_status.ValueOrDie()); const auto& fake_arguments = - MakeFakeArguments(module_or_status.ValueOrDie().get()) - .ConsumeValueOrDie(); + MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const Literal& literal) { return const_cast(&literal); }); - return test_runner_ - .Execute(std::move(module_or_status.ValueOrDie()), - fake_argument_ptrs, 
/*run_hlo_passes=*/true) - .ok() + + if (profile != nullptr) { + // We have to enable HLO profiling since otherwise currently the + // ExecutionProfile is not correct. + // + // TODO(b/119432044): Fix collection of the ExecutionProfile + // so that this is not necessary. + HloModuleConfig config = module->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_hlo_profile(true); + config.set_debug_options(debug_options); + module->set_config(config); + } + + if (!backend_config.empty()) { + // Set backend configuration if it is given. + HloInstruction* instruction = + module->entry_computation()->root_instruction(); + instruction->set_raw_backend_config_string(backend_config); + } + + // return ::testing::AssertionSuccess(); + auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); + + return output.ok() ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure(); + : ::testing::AssertionFailure() << output.status().error_message(); +} + +::testing::AssertionResult HloTestBase::RunMultipleTimes( + string_view hlo_string, bool run_hlo_passes, + std::vector* profiles, string backend_config) { + int n = profiles->size(); + std::vector> fake_argument_ptrs(n); + std::vector> fake_arguments(n); + std::vector> executables(n); + + for (int i = 0; i < n; ++i) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + std::unique_ptr module = + std::move(module_or_status.ValueOrDie()); + + fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie(); + absl::c_transform( + fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]), + [](const Literal& literal) { return const_cast(&literal); }); + + if (profiles != nullptr) { + // We have to 
enable HLO profiling since otherwise currently the + // ExecutionProfile is not correct. + // + // TODO(b/119432044): Fix collection of the ExecutionProfile + // so that this is not necessary. + HloModuleConfig config = module->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_hlo_profile(true); + config.set_debug_options(debug_options); + module->set_config(config); + } + + if (!backend_config.empty()) { + // Set backend configuration if it is given. + HloInstruction* instruction = + module->entry_computation()->root_instruction(); + instruction->set_raw_backend_config_string(backend_config); + } + + auto executable = + test_runner_.CreateExecutable(std::move(module), run_hlo_passes); + if (!executable.ok()) { + return ::testing::AssertionFailure() + << executable.status().error_message(); + } + executables[i] = std::move(executable.ValueOrDie()); + } + + for (int i = 0; i < n; ++i) { + auto output = + test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i], + /*profile=*/&((*profiles)[i])); + if (!output.ok()) { + return ::testing::AssertionFailure() << output.status().error_message(); + } + } + + return ::testing::AssertionSuccess(); } ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 1d1e7f437296a7493ef7da07039fcf6d273f35bc..3c2bcbb5df5ce94dd37f63d0c0e609f3ad2b60aa 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -46,10 +46,12 @@ class VerifiedHloModule : public HloModule { public: VerifiedHloModule(const string& name, const HloModuleConfig& config, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function shape_size_function) : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + 
verifier_( + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, + /*instruction_can_change_layout_func=*/{}, shape_size_function) {} ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } @@ -171,6 +173,11 @@ class HloTestBase : public ::testing::Test { Literal ExecuteAndTransfer(std::unique_ptr module, absl::Span arguments); + // Executes the given module on multiple replicas. + StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. @@ -219,8 +226,14 @@ class HloTestBase : public ::testing::Test { const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; - ::testing::AssertionResult Run(const absl::string_view hlo_string) - TF_MUST_USE_RESULT; + ::testing::AssertionResult Run(const absl::string_view hlo_string, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr, + string backend_config = "") TF_MUST_USE_RESULT; + ::testing::AssertionResult RunMultipleTimes( + const absl::string_view hlo_string, bool run_hlo_passes, + std::vector* profiles, + string backend_config = "") TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const absl::optional& error, const std::function& reference_preprocessor = nullptr) diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 65205f53ddc582ae477d67705f161fef1e31b857..37b2c635eebe57590e1ba73c62f015ccf399b548 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -80,7 +80,7 @@ TEST_P(IotaR2Test, DoIt) { } INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, - ::testing::Combine(::testing::Values(F32, S32), + ::testing::Combine(::testing::Values(F32, S32, BF16), ::testing::Range(/*start=*/10, /*end=*/1001, /*step=*/10), diff 
--git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 554eb24d44168caa7d7252015e3d99f2d567df9b..a2fd6070731943f15c773265f428b16f520d02ee 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -86,7 +86,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::Near( const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error_spec, bool detailed_message) { + const ErrorSpec& error_spec, absl::optional detailed_message) { return StatusToAssertion(literal_comparison::Near( expected, actual, error_spec, detailed_message, &OnMiscompare)); } @@ -97,7 +97,8 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, if (error.has_value()) { VLOG(1) << "Expects near"; return StatusToAssertion(literal_comparison::Near( - expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); + expected, actual, *error, /*detailed_message=*/absl::nullopt, + &OnMiscompare)); } VLOG(1) << "Expects equal"; return StatusToAssertion(literal_comparison::Equal(expected, actual)); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 43cca91f64b2c0fbfde5054a361cf0f95302c23d..d7cf9bed98a3eb7479b6deb6838dc388a0869360 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -93,7 +93,7 @@ class LiteralTestUtil { static ::testing::AssertionResult Near( const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error_spec, - bool detailed_message = false) TF_MUST_USE_RESULT; + absl::optional detailed_message = absl::nullopt) TF_MUST_USE_RESULT; // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. 
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index b6f9b8156b51144e4f74d285b1e4111d098f13c2..ea9b3037cf482e41238413179888f125822d161c 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -89,11 +89,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal.ToString()); + EXPECT_EQ("f32[] 2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal.ToString()); + EXPECT_EQ("f32[] 4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal.ToString()); + EXPECT_EQ("pred[] true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -105,9 +105,9 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); EXPECT_THAT(result.message(), - ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + ::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}")); EXPECT_THAT(result.message(), - ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); + ::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index a99b43f4690b3063f76e2cda1e58c9b4ba9a1df4..96527886b718bc1ea4ce8cc2d7dbeb2e3ef1d1eb 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -205,7 +205,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { 
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -233,7 +233,7 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -311,7 +311,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -842,7 +842,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { LiteralUtil::CreateR0(123456789000LL)})); } -XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { +// Disabled on interpreter backend since infeed HLO is unsupported. +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); @@ -867,7 +868,8 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } -XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { +// Disabled on interpreter backend since infeed/outfeed HLOs are unsupported. 
+XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 3f5135438fc59bea98527b1be30ee49339edd455..1fd9cb055c0bebc0f31496eb82f53a7b7a6cbfba 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -208,9 +208,7 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned( LiteralUtil::MakeTupleOwned( LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), @@ -241,9 +239,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { const = f32[4] constant({0, 0, 0, 0}) ROOT select = f32[4] select(gte0, gte1, const) })"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); @@ -273,9 +269,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { p1 = f32[3] parameter(0) ROOT map = f32[3] map(p1), to_apply=map_computation })"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); Literal result = ExecuteNoHloPasses(std::move(module), {&param});
LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); @@ -315,9 +309,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -346,9 +338,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -378,9 +368,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -410,9 +398,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -443,9 +429,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0},
f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -478,9 +462,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -513,9 +495,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); @@ -549,9 +529,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); diff --git a/tensorflow/compiler/xla/tests/plugin.bzl b/tensorflow/compiler/xla/tests/plugin.bzl index
8a5d91363b619c6b214a96ad96e92742e3052541..107869fe59d43d0a9a3e2b14af2c09e4906d9f15 100644 --- a/tensorflow/compiler/xla/tests/plugin.bzl +++ b/tensorflow/compiler/xla/tests/plugin.bzl @@ -33,4 +33,3 @@ # } plugins = {} - diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 8f2c26f0eea9c7a3b33cd77e5977924c1659535a..e49bcf26bd6e50f8fb36c86f217907b5d4901eae 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -80,7 +80,9 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { +// TODO(b/122047800): Interpreter does not support BF16 for RNG ops. +XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER( + DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) { for (int64 seed = 0; seed < 100; ++seed) { // The largest negative number smaller than zero in bf16 that's not // denormalized. @@ -103,7 +105,9 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { +// TODO(b/122047800): Interpreter does not support BF16 for RNG ops. +XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( + DISABLED_ON_CPU(ScalarBF16CountTests)))) { // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75, // they should get similar counts. bfloat16 low = static_cast(32.25); @@ -276,6 +280,39 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } +// This test verifies that two RNG instructions with the same parameters in +// the same HloComputation produce different values.
+XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) { + // Build a computation containing two identical U[0,100) RNG instructions. + auto build_computation = [this]() { + XlaBuilder builder(TestName()); + auto a = RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {10})); + auto b = RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {10})); + Tuple(&builder, {a, b}); + return builder.Build(); + }; + + ExecutionOptions execution_options = execution_options_; + execution_options.set_seed(42); + + Literal result_tuple; + { + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result_tuple, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options)); + } + + auto results = result_tuple.DecomposeTuple(); + ASSERT_EQ(results.size(), 2); + + EXPECT_FALSE(LiteralTestUtil::Equal(results[0], results[1])); +} + XLA_TEST_F(PrngTest, TenValuesN01) { XlaBuilder builder(TestName()); RngNormal(ConstantR0(&builder, 0), ConstantR0(&builder, 1), diff --git a/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc b/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e5d7db97e88936e7336ed02a5c7a1171254b0cf --- /dev/null +++ b/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class PtxasBugTest : public HloTestBase {}; + +// Checks for a bug in ptxas, tracked as Google bug 120501638, and nvidia bug +// 2459377. We never received an explanation of what exactly was going wrong +// here in ptxas. Known-bad in ptxas 10.0.145, known-good in ptxas 10.0.249. +TEST_F(PtxasBugTest, DoIt) { + const char* const kModuleStr = R"( +HloModule test + +add_F32.14 { + lhs.15 = f32[] parameter(0) + rhs.16 = f32[] parameter(1) + ROOT add.17 = f32[] add(lhs.15, rhs.16) +} + +ENTRY testcase { + arg0.1 = f32[2,5,2]{2,1,0} parameter(0) + reshape.2 = f32[2,5,2]{2,1,0} reshape(arg0.1) + constant.3 = f32[] constant(0) + pad.4 = f32[2,6,2]{2,1,0} pad(reshape.2, constant.3), padding=0_0x0_1x0_0 + reshape.5 = f32[2,3,2,2]{3,2,1,0} reshape(pad.4) + transpose.6 = f32[2,2,3,2]{3,0,2,1} transpose(reshape.5), dimensions={2,0,1,3} + reshape.7 = f32[4,3,2]{2,1,0} reshape(transpose.6) + reshape.8 = f32[4,1,3,2]{3,2,1,0} reshape(reshape.7) + transpose.9 = f32[4,2,1,3]{1,3,2,0} transpose(reshape.8), dimensions={0,3,1,2} + convert.10 = f32[4,2,1,3]{1,3,2,0} convert(transpose.9) + constant.12 = f32[] constant(0) + pad.13 = f32[4,2,1,3]{3,2,1,0} pad(convert.10, constant.12), padding=0_0x0_0x0_0x0_0 + constant.11 = f32[] constant(0) + reduce-window.18 = f32[4,2,1,3]{3,2,1,0} reduce-window(pad.13, constant.11), + window={size=1x1x1x1}, to_apply=add_F32.14 + constant.19 = f32[] constant(1) + broadcast.20 = f32[4,2,1,3]{3,2,1,0} broadcast(constant.19), dimensions={} + divide.21 = f32[4,2,1,3]{3,2,1,0} divide(reduce-window.18, broadcast.20) + convert.22 = f32[4,2,1,3]{3,2,1,0} convert(divide.21) + transpose.23 = f32[4,1,3,2]{2,1,3,0} transpose(convert.22), 
dimensions={0,2,3,1} + reshape.24 = f32[4,3,2]{2,1,0} reshape(transpose.23) + reshape.25 = f32[2,2,3,2]{3,2,1,0} reshape(reshape.24) + transpose.26 = f32[2,3,2,2]{3,1,0,2} transpose(reshape.25), dimensions={1,2,0,3} + reshape.27 = f32[2,6,2]{2,1,0} reshape(transpose.26) + slice.28 = f32[2,5,2]{2,1,0} slice(reshape.27), slice={[0:2], [0:5], [0:2]} + reshape.29 = f32[2,5,2]{2,1,0} reshape(slice.28) + tuple.30 = (f32[2,5,2]{2,1,0}) tuple(reshape.29) + ROOT get-tuple-element.31 = f32[2,5,2]{2,1,0} get-tuple-element(tuple.30), index=0 +})"; + + // Create a module with the true-default flags, not the default-for-testing + // flags. In particular, true-default flags enable unrolling, whereas for + // testing we disable unrolling, and this bug doesn't trigger without + // unrolling. + HloModuleConfig config; + config.set_debug_options(DefaultDebugOptionsIgnoringFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01})); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index f80d29b9de440b11c36e8c9bc65d4a93353a6267..e2cf4c0be289b52d5cc581ea07752ed6e98da76f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -34,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 22fe4a2670e2e0e1fedc45036a1ceec19f44e42e..30e2d24184a5d399e5e058a9c4a382f57e82866f 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -607,7 +607,10 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); - input.FillRandom(0.1f, 0.1f); + // Choose a prime iota length so that each window sees a unique set of + // values. (Technically, the requirement is that the iota length is + // relatively prime to all of the dimensions involved in the reduce-window.) + input.FillRepeatedIota(0, 137); Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -623,9 +626,9 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto reducer = param.reducer; - if (use_bfloat16() && Product(param.window_bounds) > 128) { - // To avoid numerical issues, force the reducer to be kMax for large bf16 - // windows. + if (use_bfloat16()) { + // To avoid numerical issues, force the reducer to be kMax for bf16 + // inputs. 
reducer = kMax; } @@ -949,16 +952,16 @@ struct R3ReduceWindowTestData { /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, @@ -1001,17 +1004,19 @@ TEST_P(R3ReduceWindowTest, DoIt) { const float kInitValue = 0.0f; Array3D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2]); - input.FillRandom(0.1f, 0.1f); + // Choose a prime iota length so that each window sees a unique set of values. + // (Technically, the requirement is that the iota length is relatively prime + // to all of the dimensions involved in the reduce-window.) + input.FillRepeatedIota(0, 137); Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); auto reducer = param.reducer; if (use_bfloat16()) { input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); - if (Product(param.window_bounds) > 128) { - // To avoid numerical issues, force the reducer to be kMax for large bf16 - // windows. 
- reducer = kMax; - } + + // To avoid numerical issues, force the reducer to be kMax for bf16 + // inputs. + reducer = kMax; } XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); @@ -1527,6 +1532,25 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); } +XLA_TEST_F(HloTestBase, ReduceWindowS64) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] { + %param0 = s64[] parameter(0) + ROOT %param1 = s64[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] { + %parameter.0 = s64[81,8]{1,0} parameter(0) + %parameter.1 = s64[] parameter(1) + ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); +} + XLA_TEST_F(HloTestBase, ReduceWindowF16) { const string hlo_string = R"( HloModule reduce-window diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 7ca99a91635e85cd0888e59ecde31e47fec21844..80a6868485c9162d1cb0de24f0adf3f1c1d2503a 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -79,30 +79,28 @@ string PrependDisabledIfIndicated(const string& test_case_name, // heuristic to decide whether the test case should be disabled, and we // determine whether the test case should be disabled by resolving the (test // case name, test name) in a manifest file. 
-#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class, parent_id) \ - class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ - : public parent_class { \ - public: \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ - \ - private: \ - virtual void TestBody(); \ - static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ - GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)); \ - }; \ - \ - ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)::test_info_ = \ - ::testing::internal::MakeAndRegisterTestInfo( \ - #test_case_name, \ - ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ - .c_str(), \ - nullptr, nullptr, \ - ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ - parent_class::SetUpTestCase, parent_class::TearDownTestCase, \ - new ::testing::internal::TestFactoryImpl); \ +#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \ + class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + : public parent_class { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ + \ + private: \ + virtual void TestBody(); \ + static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)); \ + }; \ + \ + ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)::test_info_ = \ + ::testing::RegisterTest( \ + #test_case_name, \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ + nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \ + return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \ + }); \ void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() // This is identical to the TEST_F macro from "gtest", but it potentially @@ -111,9 +109,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, // Per usual, you can see what 
tests are available via --gunit_list_tests and // choose to run tests that have been disabled via the manifest via // --gunit_also_run_disabled_tests. -#define XLA_TEST_F(test_fixture, test_name) \ - XLA_GTEST_TEST_(test_fixture, test_name, test_fixture, \ - ::testing::internal::GetTypeId()) +#define XLA_TEST_F(test_fixture, test_name) \ + XLA_GTEST_TEST_(test_fixture, test_name, test_fixture) // Likewise, this is identical to the TEST_P macro from "gtest", but // potentially disables the test based on the DISABLED_MANIFEST file. diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index eafa48ed7b8cf2bd67fe767ad36082661dbbd66e..67d2258928f75c078588c9425359f9468f4463ed 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -168,7 +169,7 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, StatusOr MakeFakeLiteralInternal(const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( @@ -237,6 +238,79 @@ StatusOr MakeFakeLiteralInternal(const Shape& shape, return std::move(literal); } +template +void PopulateWithRandomIntegralDataWithBounds(Literal* literal, + std::minstd_rand0* engine, + IntT min, IntT max) { + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + 
std::uniform_int_distribution generator(min, max); + for (IntT& value : literal->data()) { + value = generator(*engine); + } +} + +// Same as MakeFakeLiteralInternal but generates random numbers in the given +// range [min, max]. Currently this works only for INT types. +StatusOr MakeFakeLiteralInternalWithBounds(const Shape& shape, + std::minstd_rand0* engine, + int64 min, int64 max) { + if (shape.IsTuple()) { + std::vector elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN( + Literal element, + MakeFakeLiteralInternalWithBounds(element_shape, engine, min, max)); + elements.push_back(std::move(element)); + } + return LiteralUtil::MakeTupleOwned(std::move(elements)); + } + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + Literal literal(shape); + switch (shape.element_type()) { + case S8: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U8: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S16: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U16: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S32: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U32: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S64: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U64: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + default: + return Unimplemented( + "Unsupported type for fake random literal generation with bounds: %s", + ShapeUtil::HumanString(shape)); + } + return 
std::move(literal); +} + enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. @@ -274,16 +348,9 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -Literal MakeRandomIndex(absl::Span index_space, - std::minstd_rand0* engine) { - std::vector start_indices(index_space.size()); - if (engine != nullptr) { - for (int i = 0; i < index_space.size(); ++i) { - std::uniform_int_distribution generator(0, index_space[i]); - start_indices[i] = generator(*engine); - } - } - return LiteralUtil::CreateR1(start_indices); +Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) { + std::uniform_int_distribution generator(0, index_bound); + return LiteralUtil::CreateR0(generator(*engine)); } // Use dataflow analysis on each parameter to see if there are uses that would @@ -300,8 +367,12 @@ std::vector FindConstrainedUses( HloInstruction* instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || - (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) { + constrained_uses.push_back(instruction); + } else if ((opcode == HloOpcode::kGather || + opcode == HloOpcode::kScatter) && + op_num == 1) { constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kFusion) { const HloInstruction* const to_analyze = @@ -336,7 +407,7 @@ std::vector FindConstrainedUses( StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - std::vector index_space; + int64 index_bound = INT64_MAX; bool no_duplicates = false; bool needs_constant = false; 
ConstantType constant_type = ConstantType::kUnknown; @@ -348,19 +419,32 @@ StatusOr CreateLiteralForConstrainedUses( const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice ? use->shape() : use->operand(1)->shape(); - const int64 rank = ShapeUtil::Rank(indexed_shape); - if (!index_space.empty()) { - TF_RET_CHECK(rank == index_space.size()); - for (int64 i = 0; i < rank; ++i) { - index_space[i] = std::min( - index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - - ShapeUtil::GetDimension(slice_shape, i)); + const int64 first_index = + Cast(use)->first_index_operand_number(); + for (int64 operand = first_index; operand < use->operand_count(); + ++operand) { + if (use->operand(operand) == &param) { + index_bound = std::min( + index_bound, + ShapeUtil::GetDimension(indexed_shape, operand - first_index) - + ShapeUtil::GetDimension(slice_shape, + operand - first_index)); } - } else { - index_space.resize(rank); - for (int64 i = 0; i < rank; ++i) { - index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); + } + break; + } + case HloOpcode::kGather: + case HloOpcode::kScatter: { + const Shape& operand_shape = use->operand(0)->shape(); + if (use->operand(1) == &param) { + auto index_map = + use->opcode() == HloOpcode::kGather + ? use->gather_dimension_numbers().start_index_map() + : use->scatter_dimension_numbers() + .scatter_dims_to_operand_dims(); + for (const auto dim_in_operand : index_map) { + index_bound = + std::min(index_bound, operand_shape.dimensions(dim_in_operand)); } } break; @@ -388,13 +472,14 @@ StatusOr CreateLiteralForConstrainedUses( } int constraint_count = 0; constraint_count += no_duplicates ? 1 : 0; - constraint_count += !index_space.empty() ? 1 : 0; + constraint_count += (index_bound != INT64_MAX) ? 1 : 0; constraint_count += needs_constant ?
1 : 0; if (constraint_count > 1) { return Unimplemented("Conflicting operand generation constraints."); } - if (!index_space.empty()) { - return MakeRandomIndex(index_space, engine); + if (index_bound != INT64_MAX) { + return MakeFakeLiteralInternalWithBounds(param.shape(), engine, -1, + index_bound); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: @@ -459,8 +544,8 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, std::unique_ptr CreateCanonicalDot(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + CHECK_EQ(lhs->shape().rank(), 2); + CHECK_EQ(rhs->shape().rank(), 2); PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( 2, PrecisionConfig::DEFAULT); diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e8f5d7a9a79ebddea3cb989dbe8eab90b630d5e7..f68ee04565f3898bd3db455e3e102bc2edb6255a 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -61,11 +61,11 @@ XLA_TEST_F(TestUtilsTest, Token) { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] parameter(0) - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) @@ -79,25 +79,27 @@ XLA_TEST_F(TestUtilsTest, 
MultipleIndexSpacesForDynamicSlices) { R"(HloModule index_space_module ENTRY IndexSpace { - index_param = s32[3]{0} parameter(0) - array_param.1 = f32[123,4,789]{0,1,2} parameter(1) - array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) - dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} - ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + index_param.0 = s32[] parameter(0) + index_param.1 = s32[] parameter(1) + index_param.2 = s32[] parameter(2) + array_param.1 = f32[123,4,789]{0,1,2} parameter(3) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(4) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={3,2,2} })") .ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); - ASSERT_EQ(args.size(), 3); - const Literal& index_arg = args[0]; + ASSERT_EQ(args.size(), 5); - EXPECT_EQ(index_arg.Get({0}), 0); + EXPECT_GE(args[0].Get({}), -1); + EXPECT_LE(args[0].Get({}), 1); - EXPECT_GE(index_arg.Get({1}), 0); - EXPECT_LE(index_arg.Get({1}), 2); + EXPECT_GE(args[1].Get({}), -1); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(index_arg.Get({2}), 0); - EXPECT_LE(index_arg.Get({2}), 3); + EXPECT_GE(args[2].Get({}), -1); + EXPECT_LE(args[2].Get({}), 3); } XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { @@ -105,28 +107,30 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { R"(HloModule index_space_module ENTRY IndexSpace { - index_param = s32[3]{0} parameter(0) - array_param.1 = f32[123,4,789]{0,1,2} parameter(1) - array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) - update_param.1 = f32[1,2,3]{0,1,2} parameter(3) - update_param.2 = f32[3,2,2]{0,1,2} parameter(4) - - dynamic-update-slice.1 = 
f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) - ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + index_param.0 = s32[] parameter(0) + index_param.1 = s32[] parameter(1) + index_param.2 = s32[] parameter(2) + array_param.1 = f32[123,4,789]{0,1,2} parameter(3) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(4) + update_param.1 = f32[1,2,3]{0,1,2} parameter(5) + update_param.2 = f32[3,2,2]{0,1,2} parameter(6) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param.0, index_param.1, index_param.2) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param.0, index_param.1, index_param.2) })") .ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); - ASSERT_EQ(args.size(), 5); - const Literal& index_arg = args[0]; + ASSERT_EQ(args.size(), 7); - EXPECT_EQ(index_arg.Get({0}), 0); + EXPECT_GE(args[0].Get({}), -1); + EXPECT_LE(args[0].Get({}), 1); - EXPECT_GE(index_arg.Get({1}), 0); - EXPECT_LE(index_arg.Get({1}), 2); + EXPECT_GE(args[1].Get({}), -1); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(index_arg.Get({2}), 0); - EXPECT_LE(index_arg.Get({2}), 3); + EXPECT_GE(args[2].Get({}), -1); + EXPECT_LE(args[2].Get({}), 3); } XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { @@ -134,10 +138,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { auto module = ParseHloString(R"( HloModule sort.148.1589 +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) { %parameter.0 = f32[1048576]{0} parameter(0) %parameter.1 = s32[1048576]{0} parameter(1) - ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) 
sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} + ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}, to_apply=compare } )") .ValueOrDie(); @@ -157,10 +169,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) { auto module = ParseHloString(R"( HloModule sort.148.1589 +compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) { %parameter.0 = s32[1048576]{0} parameter(0) %parameter.1 = s32[1048576]{0} parameter(1) - ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} + ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}, to_apply=compare } )") .ValueOrDie(); @@ -180,10 +200,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) { auto module = ParseHloString(R"( HloModule sort, is_scheduled=true +compare { + p.0.lhs = bf16[] parameter(0) + p.0.rhs = bf16[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) { %parameter.0 = bf16[2,1452]{1,0} parameter(0) %parameter.1 = s32[2,1452]{1,0} parameter(1) - ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1} + ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1}, to_apply=compare } )") .ValueOrDie(); @@ -198,5 +226,105 @@ ENTRY %sort. 
(parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14 } } +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) { + auto module = ParseHloString(R"( +HloModule Test + +ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { + %parameter.1 = f32[20,20]{1,0} parameter(1) + %constant.1 = s32[1]{0} constant({0}) + %parameter.0 = s32[] parameter(0) + %bitcast.3 = s32[1]{0} bitcast(s32[] %parameter.0) + %concatenate.1 = s32[2]{0} concatenate(s32[1]{0} %constant.1, s32[1]{0} %bitcast.3), dimensions={0} + %dynamic-slice.2 = f32[20,1]{1,0} dynamic-slice(f32[20,20]{1,0} %parameter.1, s32[2]{0} %concatenate.1), dynamic_slice_sizes={20,1} + %bitcast.4 = f32[20]{0} bitcast(f32[20,1]{1,0} %dynamic-slice.2) + %dynamic-slice.3 = f32[1]{0} dynamic-slice(f32[20]{0} %bitcast.4, s32[1]{0} %bitcast.3), dynamic_slice_sizes={1} + ROOT %bitcast.5 = f32[] bitcast(f32[1]{0} %dynamic-slice.3) +} +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + EXPECT_TRUE(ShapeUtil::Equal(args[0].shape(), ShapeUtil::MakeShape(S32, {}))) + << ShapeUtil::HumanString(args[0].shape()); + EXPECT_TRUE( + ShapeUtil::Equal(args[1].shape(), ShapeUtil::MakeShape(F32, {20, 20}))) + << ShapeUtil::HumanString(args[1].shape()); +} + +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) { + auto module = ParseHloString(R"( + HloModule Test + +ENTRY %module(paramater.0: f32[200,100,300], parameter.1: s32[10,2]) -> + f32[10,300] { + %parameter.0 = f32[200,100,300] parameter(0) + %parameter.1 = s32[10,2] parameter(1) + ROOT gather = f32[10,300] gather(f32[200,100,300] %parameter.0, + s32[10,2] %parameter.1), + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, + index_vector_dim=1, + slice_sizes={1,1,300} +} +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + + const Shape& indices_shape = 
args[1].shape(); + EXPECT_TRUE( + ShapeUtil::Equal(indices_shape, ShapeUtil::MakeShape(S32, {10, 2}))) + << ShapeUtil::HumanString(indices_shape); + auto indices = args[1].data(); + for (const auto index : indices) { + EXPECT_GE(index, -1); + EXPECT_LE(index, 100); + } +} + +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) { + auto module = ParseHloString(R"( + HloModule Test + +scatter_update (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + ROOT rhs = f32[] parameter(1) +} + +ENTRY main { + operand = f32[200,100,300] parameter(0) + indices = s32[10,2] parameter(1) + updates = f32[10,300] parameter(2) + ROOT scatter = f32[200,100,300] scatter(operand, indices, updates), + to_apply=scatter_update, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 + } +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + + const Shape& indices_shape = args[1].shape(); + EXPECT_TRUE( + ShapeUtil::Equal(indices_shape, ShapeUtil::MakeShape(S32, {10, 2}))) + << ShapeUtil::HumanString(indices_shape); + auto indices = args[1].data(); + for (const auto index : indices) { + EXPECT_GE(index, -1); + EXPECT_LE(index, 100); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 601c6b06938fef1f1ae809b33209ae59b24c70a2..b77cf38ed8e29973985406015c0a3936916ad5e6 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -214,8 +214,8 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { %forty_two = f32[] constant(42.0) %add = f32[] add(f32[] %p0, f32[] %forty_two) - %token = token[] after-all(f32[] %add) - %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token) + %token0 = token[] after-all(f32[] %add) + %p1_after_token = f32[] add-dependency(f32[] %p1, 
token[] %token0) %neg = f32[] negate(f32[] %p1_after_token) ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) } @@ -236,8 +236,8 @@ HloModule AddDependencyOfConstant, is_scheduled=true ENTRY %AddDependency (p0: f32[]) -> f32[] { %p0 = f32[] parameter(0) %forty_two = f32[] constant(42.0) - %token = token[] after-all(f32[] %p0) - %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token) + %token0 = token[] after-all(f32[] %p0) + %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token0) ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token) } )"; @@ -255,8 +255,8 @@ HloModule AddDependencyAsRoot, is_scheduled=true ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) %neg = f32[3] negate(f32[3] %p) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( @@ -274,9 +274,9 @@ ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { %p0 = f32[3] parameter(0) %p1 = f32[3] parameter(1) %forty_two = f32[] constant(42.0) - %token = token[] after-all() - %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two) - %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token) + %token0 = token[] after-all() + %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token0, f32[3] %p1, f32[] %forty_two) + %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token0) %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0 %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2 ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2) diff --git 
a/tensorflow/compiler/xla/tests/triangular_solve_test.cc b/tensorflow/compiler/xla/tests/triangular_solve_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..24ab12136ff396bd9ac37bb058311b0d2d6f2515 --- /dev/null +++ b/tensorflow/compiler/xla/tests/triangular_solve_test.cc @@ -0,0 +1,502 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using TriangularSolveTest = ClientLibraryTestBase; +using TriangularSolveLeftLookingTest = ClientLibraryTestBase; + +static constexpr float kNan = std::numeric_limits::quiet_NaN(); + +Array2D AValsLower() { + return {{2, kNan, kNan, kNan}, + {3, 6, kNan, kNan}, + {4, 7, 9, kNan}, + {5, 8, 10, 
11}}; +} + +Array2D AValsUpper() { + return {{2, 3, 4, 5}, + {kNan, 6, 7, 8}, + {kNan, kNan, 9, 10}, + {kNan, kNan, kNan, 11}}; +} + +Array2D AValsLowerUnitDiagonal() { + return {{kNan, kNan, kNan, kNan}, + {3, kNan, kNan, kNan}, + {4, 7, kNan, kNan}, + {5, 8, 10, kNan}}; +} + +Array2D AValsUpperUnitDiagonal() { + return {{kNan, 3, 4, 5}, + {kNan, kNan, 7, 8}, + {kNan, kNan, kNan, 10}, + {kNan, kNan, kNan, kNan}}; +} + +Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +static constexpr complex64 kNanC64 = complex64(kNan, kNan); + +Array2D AValsLowerComplex() { + return {{2, kNanC64, kNanC64, kNanC64}, + {complex64(3, 1), 6, kNanC64, kNanC64}, + {4, complex64(7, 2), 9, kNanC64}, + {5, 8, complex64(10, 3), 11}}; +} + +Array2D AValsUpperComplex() { + return {{2, 3, complex64(4, 3), 5}, + {kNanC64, 6, complex64(7, 2), 8}, + {kNanC64, kNanC64, complex64(9, 1), 10}, + {kNanC64, kNanC64, kNanC64, 11}}; +} + +Array2D BValsRightComplex() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +Array2D BValsLeftComplex() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +XLA_TEST_F(TriangularSolveTest, EmptyArrays) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(Array2D(0, 0), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(Array2D(0, 10), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + + ComputeAndCompareR2(&builder, Array2D(0, 10), + {a_data.get(), b_data.get()}); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/false, 
/*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + + Array2D expected({ + {0.5, 0.08333334, 0.04629629, 0.03367003}, + {2.5, -0.25, -0.1388889, -0.1010101}, + {4.5, -0.58333331, -0.32407406, -0.23569024}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected({ + {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, + {0.64393939, 0.06565657, -0.03030303, 0.72727273}, + {1.4520202, 0.2003367, 0.01010101, 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + + Array2D expected({ + {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, + {0.64393939, 0.06565657, -0.03030303, 0.72727273}, + {1.4520202, 0.2003367, 0.01010101, 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", 
&builder, &b); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected({ + {0.5, 0.08333334, 0.04629629, 0.03367003}, + {2.5, -0.25, -0.1388889, -0.1010101}, + {4.5, -0.58333331, -0.32407406, -0.23569024}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + + Array2D expected({ + {-0.89646465, -0.69444444, -0.49242424}, + {-0.27441077, -0.24074074, -0.20707071}, + {-0.23232323, -0.22222222, -0.21212121}, + {0.90909091, 1., 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(AValsLowerUnitDiagonal(), 0, "a", &builder, &a); + 
auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*unit_diagonal=*/true, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected( + {{1., 2., 3.}, {1., -1., -3.}, {-4., 7., 18.}, {37., -61., -159.}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + + Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", 
&builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected({ + {-0.89646465, -0.69444444, -0.49242424}, + {-0.27441077, -0.24074074, -0.20707071}, + {-0.23232323, -0.22222222, -0.21212121}, + {0.90909091, 1., 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(AValsUpperUnitDiagonal(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/true, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected({{-1402., -1538., -1674.}, + {575., 631., 687.}, + {-93., -102., -111.}, + {10., 11., 12.}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::ADJOINT); + + Array2D expected({ + {0.5, complex64(0.08333333, 0.08333333), + complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, + {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), + complex64(0.08670034, -0.02104377)}, + {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296), + complex64(0.11026936, -0.03114478)}, + }); + + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 
1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + + Array2D expected({ + {0.5, 1., 1.5}, + {0.41666667, 0.33333333, 0.25}, + {complex64(0.20020325, -2.81504065e-01), + complex64(0.13821138, -4.22764228e-01), + complex64(0.07621951, -5.64024390e-01)}, + {complex64(0.19678492, 2.55912786e-01), + complex64(0.17738359, 3.84331116e-01), + complex64(0.15798226, 5.12749446e-01)}, + }); + + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) { + XlaBuilder builder(TestName()); + + Array3D bvals(7, 5, 5); + bvals.FillIota(1.); + + // Set avals to the upper triangle of bvals. 
+ Array3D avals = bvals; + avals.Each([](absl::Span indices, float* value) { + if (indices[1] > indices[2]) { + *value = 0; + } + }); + + XlaOp a, b; + auto a_data = CreateR3Parameter(avals, 0, "a", &builder, &a); + auto b_data = CreateR3Parameter(bvals, 1, "b", &builder, &b); + BatchDot( + ConstantR3FromArray3D(&builder, avals), + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE)); + + ComputeAndCompareR3(&builder, bvals, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +struct TriangularSolveTestSpec { + int m, n; // A is mxm, B is mxn + bool left_side; + bool lower; + TriangularSolveOptions::Transpose transpose_a; +}; + +class TriangularSolveParametricTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(TriangularSolveParametricTest, Random) { + TriangularSolveTestSpec spec = GetParam(); + + XlaBuilder builder(TestName()); + + Array2D avals(spec.m, spec.m); + avals.FillRandom(1.0); + for (int i = 0; i < spec.m; ++i) { + avals(i, i) += 10; + } + + std::pair bdims = spec.left_side ? 
std::make_pair(spec.m, spec.n) + : std::make_pair(spec.n, spec.m); + Array2D bvals(bdims.first, bdims.second); + bvals.FillRandom(1.0); + + XlaOp a, b; + auto a_data = CreateR2Parameter(avals, 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(bvals, 1, "b", &builder, &b); + auto x = TriangularSolve(a, b, spec.left_side, spec.lower, + /*unit_diagonal=*/false, spec.transpose_a); + auto a_tri = Triangle(a, spec.lower); + a_tri = MaybeTransposeInMinorDims( + a_tri, spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE); + if (spec.left_side) { + BatchDot(a_tri, x); + } else { + BatchDot(x, a_tri); + } + + ComputeAndCompareR2(&builder, bvals, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +std::vector TriangularSolveTests() { + std::vector specs; + for (int m : {5, 10}) { + for (int n : {5, 10}) { + for (bool left_side : {false, true}) { + for (bool lower : {false, true}) { + for (TriangularSolveOptions::Transpose transpose_a : + {TriangularSolveOptions::NO_TRANSPOSE, + TriangularSolveOptions::TRANSPOSE}) { + specs.push_back({m, n, left_side, lower, transpose_a}); + } + } + } + } + } + return specs; +} + +INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation, + TriangularSolveParametricTest, + ::testing::ValuesIn(TriangularSolveTests())); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 27ce243e9bd4afbdcc1fdc5b6873d4968086e459..cdf2c34fcc3cc005e84626c39c8ab301a9040529 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -176,8 +176,9 @@ XLA_TEST_F(TupleTest, AddTupleElements) { {2.f, 4.f, 6.f}, // row 0 {5.f, 7.f, 9.f}, // row 1 }); - ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3})); - ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3})); + ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3}))); + 
ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, + ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3}))); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } @@ -512,8 +513,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { class TupleHloTest : public HloTestBase {}; -// Disabled on the interpreter because bitcast doesn't exist on the interpreter. -XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { +XLA_TEST_F(TupleHloTest, BitcastAfterGTE) { const char* testcase = R"( HloModule m, is_scheduled=true @@ -525,9 +525,7 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy) } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -555,13 +553,11 @@ XLA_TEST_F(TupleHloTest, s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1) gte = f32[2] get-tuple-element(s), index=0 tuple = (f32[2]) tuple(gte) - token = token[] after-all() - ROOT outfeed = token[] outfeed(tuple, token) + token0 = token[] after-all() + ROOT outfeed = token[] outfeed(tuple, token0) } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param0 = LiteralUtil::CreateR1({1, 2}); auto param1 = LiteralUtil::CreateR1({2, 3}); auto param4 = LiteralUtil::CreateR0(false); diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 4fbd7f2fb174ac899c1e3b23801986cb52db96a2..c51f30f3b5db95962a719ec226dd03f41142a782 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -64,7 +64,9 @@ class UnaryOpTest : public ClientLibraryTestBase { &builder, {-2, 25, 0, 
static_cast(-0.0), -123, inf(), -inf()}); Sign(arg); - ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); + ComputeAndCompareR1( + &builder, + {-1, 1, static_cast(+0.0), static_cast(-0.0), -1, 1, -1}, {}); } template diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 6d5f276e82087cedc356691b0ff08df24cec8d20..85212fa56d71088156d2f3edda17f71cdab56da2 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -861,7 +861,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Update. auto update = ConvertElementType(Broadcast(out0, {2}), F32); // Starts = iteration * 2; - auto starts = Reshape(Mul(iteration, ConstantR0(&builder, 2)), {1}); + auto starts = Mul(iteration, ConstantR0(&builder, 2)); // UpdateSlice. auto out1 = DynamicUpdateSlice(input, update, starts); @@ -901,7 +901,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { +XLA_TEST_F(WhileTest, WhileWithPrngScalarResult) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. @@ -1146,7 +1146,7 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { +XLA_TEST_F(WhileTest, WhileWithCallInsideCondition) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. 
@@ -1299,9 +1299,9 @@ void BM_WhileLoop(int num_iters) { auto one = ConstantR0(&builder, 1.0); auto update = Broadcast(one, {1, 1024, 1024}); // Starts = iteration * 2; - auto starts = ConstantR1(&builder, {0, 0, 0}); + auto zero = ConstantR0(&builder, 0); // UpdateSlice. - auto out1 = DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, {zero, zero, zero}); Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index e57d072a0632b492b8b6e34439f4e80332b843b6..7b7b8f5d02dc99607b30f898e18c5b448d421e07 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -40,8 +40,6 @@ limitations under the License. namespace xla { namespace { -namespace gtl = ::tensorflow::gtl; - class HloProfileTest : public ClientLibraryTestBase {}; struct ParsedProfileOutputLine { @@ -174,9 +172,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, exec_run_options.set_allocator(backend->memory_allocator()); exec_run_options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); - ServiceExecutableRunOptions run_options( - exec_run_options, /*borrow_stream=*/nullptr, - backend->eigen_intra_op_thread_pool()); + ServiceExecutableRunOptions run_options(exec_run_options, + /*borrow_stream=*/nullptr); std::vector args = {&lhs_arg, &rhs_arg}; TF_ASSERT_OK_AND_ASSIGN( auto execution_result, @@ -225,14 +222,17 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { line_no++; // Skip 'Execution profile for ....' 
+ ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/false, &parsed_profile_lines)); + ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/true, &parsed_profile_lines)); + ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/true, &parsed_profile_lines)); diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index cdde88c1359416d423685f330e9cbdf77948034f..c78ec522aa5f13556c6d4602267544694887f250 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -66,7 +67,7 @@ StatusOr TextLiteralReader::ReadAllLines() { } absl::StripAsciiWhitespace(&shape_string); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); + TF_ASSIGN_OR_RETURN(Shape shape, ParseShape(shape_string)); if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 7289ae7df65e56652eeeb67e536e4c721d97d999..fc7949d889dc8ed9fac425982cc555a6c42a7f1d 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 8926bbed2b54fceaaf0e6e991f0e881d35731ef4..ebd4bb1e42c9d1dc1f72a75514e916a2d900c30e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -14,7 +14,7 @@ filegroup( visibility = ["//tensorflow/compiler/xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") tf_cc_binary( name = "hex_floats_to_packed_literal", @@ -29,33 +29,6 @@ tf_cc_binary( ], ) -cc_library( - name = "dumped_computation_to_graphviz_library", - srcs = ["dumped_computation_to_graphviz.cc"], - deps = [ - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_binary( - name = "dumped_computation_to_graphviz", - deps = [ - ":dumped_computation_to_graphviz_library", - "//tensorflow/compiler/xla/service:interpreter_plugin", - ], -) - tf_cc_binary( name = "show_signature", srcs = ["show_signature.cc"], @@ -95,6 +68,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", + "//tensorflow/compiler/xla/service/gpu:outfeed_manager", 
"//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -204,33 +178,66 @@ tf_cc_binary( ) tf_cc_binary( - name = "dumped_computation_to_tf_graphdef", - srcs = ["dumped_computation_to_tf_graphdef.cc"], + name = "hlo_proto_to_json", + srcs = ["hlo_proto_to_json.cc"], deps = [ - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:interpreter_plugin", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_extractor_test", + srcs = ["hlo_extractor_test.cc"], + deps = [ + ":hlo_extractor", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_extractor", + srcs = ["hlo_extractor.cc"], + hdrs = ["hlo_extractor.h"], + deps = [ + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_verifier", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", ], ) tf_cc_binary( - name = "hlo_proto_to_json", - srcs = ["hlo_proto_to_json.cc"], + name = "interactive_graphviz", + srcs = ["interactive_graphviz.cc"], deps = [ - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", + ":hlo_extractor", + "//tensorflow/compiler/xla/client:client_library", + 
"//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) + +sh_test( + name = "interactive_graphviz_build_only_test", + srcs = ["interactive_graphviz_test.sh"], + data = [":interactive_graphviz"], +) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc deleted file mode 100644 index b623556468fb4a5d96be614b6c067d5a1df51a6f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Usage: dumped_computation_to_graphviz some_binary_snapshot_proto* -// -// Dumps a graphviz URL for a snapshot computation to the command line. 
-// -// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from -// ServiceInterface::SnapshotComputation to disk. -// -// The GraphViz URL is placed into the log stderr, whereas computation -// statistics are printed on stdout (implementation note: getting computation -// statistics is how we trigger compilation to split out a GraphViz URL). - -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace tools { - -void RealMain(absl::Span args) { - Client* client = ClientLibrary::LocalClientOrDie(); - for (char* arg : args) { - HloSnapshot module; - TF_CHECK_OK( - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - XlaComputation computation = - client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_generate_hlo_graph(".*"); - ComputationStats stats = - client->GetComputationStats(computation, debug_options) - .ConsumeValueOrDie(); - fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); - } -} - -} // namespace tools -} // namespace xla - -int main(int argc, char** argv) { - std::vector flag_list; - xla::AppendDebugOptionsFlags(&flag_list); - xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - const bool parse_result = 
tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_result) { - LOG(ERROR) << "\n" << usage; - return 2; - } - tensorflow::port::InitMain(argv[0], &argc, &argv); - - absl::Span args(argv, argc); - args.remove_prefix(1); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); - return 0; -} diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 4375e7c138c9e8d193feaa7a39d63946c4ea3086..df2d3d18b9ff86c0dd2047c2415527aeb1c1f154 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 723569862c7550387e95003e3a673743464b67b8..35bb82ca22f46d2cdeaac3b9a87b253efe9a07d9 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc deleted file mode 100644 index f8bb9a6b1e217fc4e6e15c8a3302be61ed339c82..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Usage: dumped_computation_to_tf_graph some_binary_snapshot_proto* -// -// Dumps a tensorflow GraphDef in text format for a snapshot computation. The -// dumped graph is an HLO computation with HLO instructions as nodes and can be -// visualized on Tensorboard. Upload the dumped files on Tensorboard. -// -// some_binary_snapshot_proto is obtained by serializing the SessionModule from -// ServiceInterface::SnapshotComputation to disk. 
- -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" - -using tensorflow::Env; - -namespace xla { -namespace tools { - -void RealMain(absl::Span args) { - Client* client = ClientLibrary::LocalClientOrDie(); - for (char* arg : args) { - HloSnapshot module; - TF_CHECK_OK( - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - XlaComputation computation = - client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_generate_hlo_graph(".*"); - debug_options.set_xla_hlo_dump_as_graphdef(true); - ComputationStats stats = - client->GetComputationStats(computation, debug_options) - .ConsumeValueOrDie(); - fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); - } -} - -} // namespace tools -} // namespace xla - -int main(int argc, char** argv) { - std::vector flag_list; - xla::AppendDebugOptionsFlags(&flag_list); - xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_result) { - LOG(ERROR) << "\n" << usage; - return 2; - } - - tensorflow::port::InitMain(argv[0], &argc, &argv); - - absl::Span args(argv, argc); - args.remove_prefix(1); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); - return 0; -} diff --git 
a/tensorflow/compiler/xla/tools/hlo_extractor.cc b/tensorflow/compiler/xla/tools/hlo_extractor.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3ce5f99b0c2a8e9ae5446f4bedc34b678c95b96 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_extractor.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace { + +// Visitor that build a new HLO module with an entry computation and a root that +// is provided to the visit function. Only HLOs that are reachable from the new +// root instruction are included in the new module. +// +// The constructor allows specifying a set of boundary HLOs to prune the HLO +// graph. HLOs at the boundary are replaced with parameters. Can be nullptr +// which means no boundary, i.e. no HLOs are replaced with parameters. 
+class ExtractionVisitor : public ConstDfsHloVisitorWithDefault { + public: + explicit ExtractionVisitor( + const HloModule& old_module, + absl::flat_hash_set* boundary) + : old_module_(old_module), + module_(absl::make_unique("extracted", config_)), + clone_context_(module_.get()), + builder_("entry_computation"), + boundary_(boundary) {} + + Status HandleParameter(const HloInstruction* parameter) override { + // Entry parameters need renumbering. + auto new_parameter = HloInstruction::CreateParameter( + parameter_number_++, parameter->shape(), parameter->name()); + clone_context_.MapInstruction(parameter, new_parameter.get()); + builder_.AddInstruction(std::move(new_parameter)); + return Status::OK(); + } + + Status DefaultAction(const HloInstruction* hlo) override { + // Replace instructions at the boundary with parameters, but leave constants + // untouched. + if (boundary_ != nullptr && boundary_->count(hlo) > 0) { + auto new_parameter = HloInstruction::CreateParameter( + parameter_number_, hlo->shape(), hlo->name()); + parameter_number_++; + clone_context_.MapInstruction(hlo, new_parameter.get()); + builder_.AddInstruction(std::move(new_parameter)); + return Status::OK(); + } + std::vector new_operands; + for (auto operand : hlo->operands()) { + new_operands.push_back(clone_context_.GetInstruction(operand)); + } + auto instruction = + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context_); + builder_.AddInstruction(std::move(instruction)); + return Status::OK(); + } + + Status FinishVisit(const HloInstruction* /*root*/) override { + module_->AddEntryComputation(builder_.Build()); + // Rename HLOs so that their name matches the original. By default, + // HLOs get new unique names when adding a new entry computation to + // a module. 
+ for (auto computation : old_module_.MakeComputationPostOrder()) { + for (auto old_instruction : computation->MakeInstructionPostOrder()) { + if (auto new_instruction = + clone_context_.FindInstruction(old_instruction)) { + new_instruction->SetAndSanitizeName(old_instruction->name()); + } + } + } + return Status::OK(); + } + + HloModule* module() { return module_.get(); } + + std::unique_ptr ConsumeModule() { return std::move(module_); } + + private: + const HloModule& old_module_; + HloModuleConfig config_; + std::unique_ptr module_; + HloCloneContext clone_context_; + HloComputation::Builder builder_; + absl::flat_hash_set* boundary_; + int64 parameter_number_ = 0; +}; + +void ComputeBoundary(const HloInstruction* root, int64 limit, + absl::flat_hash_set* boundary) { + std::deque worklist; + absl::flat_hash_map visited; + worklist.push_back(root); + visited.emplace(root, 0); + while (!worklist.empty()) { + auto hlo = worklist.front(); + worklist.pop_front(); + int64 hops = visited[hlo]; + if (hops > limit) { + boundary->insert(hlo); + continue; + } + for (const HloInstruction* operand : hlo->operands()) { + if (visited.count(operand)) { + continue; + } + worklist.push_back(operand); + visited.emplace(operand, hops + 1); + } + } +} + +} // namespace + +std::unique_ptr ExtractModule(HloInstruction* instruction, + int64 height) { + absl::flat_hash_set boundary; + if (height != -1) { + ComputeBoundary(instruction, height, &boundary); + } + ExtractionVisitor visitor(*instruction->GetModule(), &boundary); + CHECK(instruction->Accept(&visitor).ok()); + + // The first pass may leave unused parameter instructions. Do another + // extraction pass to remove unused parameters. This is done because + // HloComputation does not allow removing parameters after the computation has + // been built. 
+ ExtractionVisitor cleanup_visitor(*visitor.module(), /*boundary=*/nullptr); + TF_CHECK_OK(visitor.module()->entry_computation()->root_instruction()->Accept( + &cleanup_visitor)); + + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); + TF_CHECK_OK(verifier.Run(cleanup_visitor.module()).status()); + return cleanup_visitor.ConsumeModule(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/hlo_extractor.h b/tensorflow/compiler/xla/tools/hlo_extractor.h new file mode 100644 index 0000000000000000000000000000000000000000..bc13dc7e438fe0e64312746150af02df805e746a --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// Creates a new HLO module rooted with an entry computation rooted at the given +// instruction. +// +// By default (height == -1), the new computation includes all transitive +// operands of `root`. If you specify a different height, the new computation +// will include all instructions <= `height` hops away from `root`. 
+// Instructions at the boundary are replaced by parameters. +std::unique_ptr ExtractModule(HloInstruction* instruction, + int64 height = -1); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ diff --git a/tensorflow/compiler/xla/tools/hlo_extractor_test.cc b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4beb099b330cadf4540944979f38681bae07103c --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_extractor.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = testing::opcode_matchers; + +using HloExtractorTest = HloTestBase; + +TEST_F(HloExtractorTest, ExtractTopLevel) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + param.0 = f32[4]{0} parameter(0) + negate = f32[4]{0} negate(f32[4]{0} param.0) + ROOT exp = f32[4]{0} exponential(f32[4]{0} negate) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp")); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Exp(op::Negate(op::Parameter(0)))); + } + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Exp(op::Parameter(0))); + } + + { + auto extracted_module = ExtractModule( + FindInstruction(hlo_module.get(), "negate"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Negate(op::Parameter(0))); + } +} + +TEST_F(HloExtractorTest, ExtractDag) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + param.0 = f32[4]{0} parameter(0) + tanh = f32[4]{0} tanh(f32[4]{0} param.0) + negate = f32[4]{0} negate(f32[4]{0} tanh) + exp = f32[4]{0} exponential(f32[4]{0} negate) + ROOT add = f32[4]{0} add(f32[4]{0} negate, f32[4]{0} exp) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp")); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + 
op::Exp(op::Negate(op::Tanh(op::Parameter(0))))); + } + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Parameter(0), op::Parameter(1))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/1); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Negate(op::Parameter(0)), + op::Exp(op::Negate(op::Parameter(0))))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/2); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Negate(op::Tanh(op::Parameter(0))), + op::Exp(op::Negate(op::Tanh(op::Parameter(0)))))); + } +} + +TEST_F(HloExtractorTest, ExtractWithConstant) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + p = f32[4]{0} parameter(0) + tanh = f32[4]{0} tanh(p) + c = f32[4]{0} constant({1, 2, 3, 4}) + ROOT add = f32[4]{0} add(tanh, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Parameter(0), op::Parameter(1))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/1); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Tanh(op::Parameter(0)), op::Constant())); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc new file mode 100644 index 0000000000000000000000000000000000000000..0c7c078b9b9d30427cb01b8930bd012046d852d3 --- /dev/null +++ 
b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -0,0 +1,676 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A tool for interactively exploring graphviz dumps of HLO graphs. +// +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// +// Generated visualization is opened in a new default browser window using +// /usr/bin/sensible-browser. 
+
+#include <stdio.h>
+#include <unistd.h>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/tools/hlo_extractor.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/util/command_line_flags.h"
+#if defined(PLATFORM_GOOGLE)
+#include "util/readline/readline.h"
+#endif
+
+namespace xla {
+namespace tools {
+namespace {
+
+// Prints `prompt` and reads one line of input into `*line`.  Returns false
+// once no further line could be read (end of input or a stream error).
+bool ReadLine(const char *prompt, string *line) {
+#if defined(PLATFORM_GOOGLE)
+  return util::ReadLine(prompt, line);
+#else
+  std::cout << prompt;
+  // Test the extraction itself rather than std::cin.good(): when the final
+  // line of piped input has no trailing newline, getline stores the text and
+  // sets only eofbit, and good() would make the caller drop that last command.
+  return static_cast<bool>(std::getline(std::cin, *line));
+#endif
+}
+
+// Command-line opts to this tool.  See main() for descriptions of these
+// fields.
+struct Options {
+  string hlo_snapshot;
+  string hlo_proto;
+  string hlo_text;
+  string platform;
+  string browser;
+};
+
+const char* const kUsage = R"(
+This tool lets you load an XLA dump and then interactively explore its graphical
+representation.
+
+Most models are too large to visualize in their entirety using graphviz, but
+it's still useful to be able to look at the nodes "near" a particular node of
+interest.
+
+If you pass --platform, this tool will compile the HloModule for the given
+platform.  This means that if you acquired your proto from a binary running at a
+particular CL, the HLO graph it ran isn't necessarily the same as the one shown
+here, unless this program was built at the same CL (and our compiler is
+deterministic :).
+
+Be patient when starting this program if you give it a large input; it has to
+compile the whole thing.
+
+Usage:
+
+  interactive_graphviz -- \
+    --{hlo_snapshot,hlo_proto,hlo_text}=path/to/binary_proto
+    --platform={CUDA,CPU,...}
+)";
+
+// Unless an explicit width is specified, we will render a neighborhood of
+// kDefaultWidth nodes around the requested instruction.
+constexpr int64 kDefaultWidth = 2;
+
+// When printing all paths between two nodes, we print out only this many nodes
+// by default, truncating the graph if there are more nodes than this in the
+// all-paths set.
+constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;
+
+using absl::EqualsIgnoreCase;
+
+// A global control for whether backend configuration display is enabled.
+bool show_backend_config = true;
+
+// Returns the first instruction in `module` whose name matches `node_name`
+// case-insensitively, or nullptr if there is none.  A leading "%" on
+// `node_name` is optional.
+HloInstruction* FindInstruction(const HloModule& module, string node_name) {
+  if (absl::StartsWith(node_name, "%")) {
+    node_name.erase(node_name.begin());
+  }
+  for (const auto& computation : module.computations()) {
+    auto instrs = computation->instructions();
+    auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
+      // Try with and without "%" at the beginning of the node name.
+      return EqualsIgnoreCase(instr->name(), node_name) ||
+             EqualsIgnoreCase(instr->name(), absl::StrCat("%", node_name));
+    });
+    if (it != instrs.end()) {
+      return *it;
+    }
+  }
+  return nullptr;
+}
+
+// Returns the computation in `module` whose name matches `comp_name`
+// case-insensitively, or nullptr if there is none.
+HloComputation* FindComputation(const HloModule& module,
+                                const string& comp_name) {
+  for (auto* computation : module.computations()) {
+    if (EqualsIgnoreCase(computation->name(), comp_name)) {
+      return computation;
+    }
+  }
+  return nullptr;
+}
+
+// Print a help message describing the various available commands.
+void DoHelpCommand() {
+  // The text is split around the two numeric defaults so that the help output
+  // always reflects the actual constants.
+  std::cout << R"(Commands:
+  <instruction> [<width>] [/ <boundary_instruction>+]
+    Renders a neighborhood of <width> nodes around <instruction>, without going
+    beyond the optional boundary instructions.  If <width> is not provided,
+    the default value is )"
+            << kDefaultWidth << R"(.
+  allpaths <instruction> <instruction> [<n>]
+    Renders a subset of all paths from one instruction to the other.  Either
+    order of nodes is accepted.  Shows the <n> nodes in the all-paths set on the
+    shortest paths; default is )"
+            << kDefaultMaxNumNodesInAllPaths << R"(.
+  <computation>
+    Renders all nodes in <computation>.
+  backend_config [on|off]
+    Controls whether backend operation configuration information is printed.
+  list [name|op_name|op_type] <pattern>
+    Lists all instructions whose name, metadata op_name, or metadata op_type
+    contains <pattern> as a substring.
+  list computations
+    Lists all computations in the module.
+  info <instruction>
+  info <computation>
+    Prints information about <instruction> or <computation>.
+  extract <instruction> [<height>]
+    Creates a new HLO module with <instruction> as entry computation root. If
+    <height> is specified, the new computation contains nodes up to <height>
+    nodes above the root.
+  help
+    Prints this usage information.)"
+            << std::endl;
+}
+
+// Turn backend-configuration printing on or off, then report the current
+// setting.  With no argument the setting is only reported; any argument other
+// than "on" or "off" is rejected and leaves the setting unchanged.
+void DoBackendConfigCommand(const std::vector<string>& tokens) {
+  if (tokens.size() == 2 && tokens[1] == "on") {
+    show_backend_config = true;
+  } else if (tokens.size() == 2 && tokens[1] == "off") {
+    show_backend_config = false;
+  } else if (tokens.size() != 1) {
+    std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
+              << std::endl;
+  }
+  std::cout << "Backend configuration display "
+            << (show_backend_config ? "ON" : "OFF") << std::endl;
+}
+
+// List all computations in the module.
+void DoListComputationsCommand(const HloModule& module,
+                               const std::vector<string>& tokens) {
+  if (tokens.size() > 2) {
+    std::cout << R"(Illegal syntax; "list computations" takes no arguments.)"
+              << std::endl;
+    return;
+  }
+  if (module.entry_computation() != nullptr) {
+    std::cout << "Entry computation:" << std::endl;
+    std::cout << " " << module.entry_computation()->name() << std::endl
+              << std::endl;
+  }
+  std::cout << "Subcomputations:" << std::endl;
+  for (const auto& computation : module.computations()) {
+    // The entry computation was already printed above.
+    if (computation == module.entry_computation()) {
+      continue;
+    }
+    std::cout << " " << computation->name() << std::endl;
+  }
+}
+
+// List all instructions matching a pattern.  `tokens` is either
+// {"list", pattern} (match against the instruction name) or
+// {"list", type, pattern} with type one of name, op_name, op_type.
+void DoListCommand(const HloModule& module, const std::vector<string>& tokens) {
+  string pattern = "";
+  string type = "name";
+  if (tokens.size() == 2) {
+    pattern = tokens[1];
+  } else if (tokens.size() == 3) {
+    type = tokens[1];
+    pattern = tokens[2];
+  } else {
+    std::cout << "Illegal list query syntax. Use "
+              << R"("list [name|op_name|op_type] pattern".)" << std::endl;
+    return;
+  }
+
+  std::cout << "Query results:" << std::endl;
+  for (const auto& computation : module.computations()) {
+    for (const auto& instr : computation->instructions()) {
+      if ((type == "name" && instr->name().find(pattern) != string::npos) ||
+          (type == "op_name" &&
+           instr->metadata().op_name().find(pattern) != string::npos) ||
+          (type == "op_type" &&
+           instr->metadata().op_type().find(pattern) != string::npos)) {
+        std::cout << " " << instr->name();
+        std::cout << ", op_name '" << instr->metadata().op_name() << "'";
+        std::cout << ", op_type '" << instr->metadata().op_type() << "'";
+        std::cout << std::endl;
+      }
+    }
+  }
+}
+
+// Print info about an instruction or computation.  If the name matches both a
+// computation and an instruction, the computation is described.
+void DoInfoCommand(const HloModule& module, const std::vector<string>& tokens) {
+  if (tokens.size() != 2) {
+    std::cerr << "Illegal info query syntax. Use "
+              << R"("info name".)" << std::endl;
+    return;
+  }
+  string node_name = tokens[1];
+
+  const HloInstruction* instr = FindInstruction(module, node_name);
+  const HloComputation* comp = FindComputation(module, node_name);
+  if (!instr && !comp) {
+    std::cerr << "Couldn't find HloInstruction or HloComputation named "
+              << node_name << std::endl;
+    return;
+  }
+
+  if (comp != nullptr) {
+    std::cout << "HloComputation " << comp->name() << std::endl;
+    if (comp->IsFusionComputation()) {
+      std::cout << " Fusion instruction: " << comp->FusionInstruction()->name()
+                << std::endl;
+    }
+    std::cout << " Parameters:" << std::endl;
+    for (const auto& param : comp->parameter_instructions()) {
+      std::cout << " " << param->name() << " ("
+                << ShapeUtil::HumanStringWithLayout(param->shape()) << ")"
+                << std::endl;
+    }
+    HloInstruction* root = comp->root_instruction();
+    std::cout << " Root instruction: " << root->name() << " ("
+              << ShapeUtil::HumanStringWithLayout(root->shape()) << ")"
+              << std::endl;
+
+    auto embedded_computations = comp->MakeEmbeddedComputationsList();
+    std::cout << " " << embedded_computations.size() << " embedded computation"
+              << (embedded_computations.size() != 1 ? "s" : "")
+              << (!embedded_computations.empty() ? ":" : ".") << std::endl;
+    for (const HloComputation* c : embedded_computations) {
+      std::cout << " " << c->name() << std::endl;
+    }
+
+    // Find which computations reference comp as an embedded computation.
+    std::vector<const HloComputation*> users;
+    for (const HloComputation* c : module.computations()) {
+      if (absl::c_linear_search(c->MakeEmbeddedComputationsList(), comp)) {
+        users.push_back(c);
+      }
+    }
+    // End the summary line before listing the users, as is done for the
+    // embedded computations above.
+    std::cout << " Used by " << users.size() << " computation"
+              << (users.size() != 1 ? "s" : "") << (!users.empty() ? ":" : ".")
+              << std::endl;
+    for (const HloComputation* c : users) {
+      std::cout << " " << c->name() << std::endl;
+    }
+  } else {
+    std::cout << "HloInstruction " << instr->name() << std::endl;
+    std::cout << " Parent computation: " << instr->parent()->name()
+              << std::endl;
+    std::cout << " Opcode: " << HloOpcodeString(instr->opcode()) << std::endl;
+    std::cout << " Shape: " << ShapeUtil::HumanStringWithLayout(instr->shape())
+              << std::endl;
+    std::cout << " Metadata:" << std::endl;
+    if (!instr->metadata().op_name().empty()) {
+      std::cout << " Name: " << instr->metadata().op_name() << std::endl;
+    }
+    if (!instr->metadata().op_type().empty()) {
+      std::cout << " Type: " << instr->metadata().op_type() << std::endl;
+    }
+    if (!instr->raw_backend_config_string().empty()) {
+      std::cout << " Backend configuration: "
+                << instr->raw_backend_config_string() << std::endl;
+    }
+    if (instr->opcode() == HloOpcode::kFusion) {
+      std::cout << " Fusion kind: " << xla::ToString(instr->fusion_kind())
+                << std::endl;
+      std::cout << " Fusion computation: "
+                << instr->fused_instructions_computation()->name() << std::endl;
+      std::cout << " Fused computation root: "
+                << instr->fused_expression_root()->name() << std::endl;
+    }
+    std::cout << " Operands:" << std::endl;
+    for (HloInstruction* operand : instr->operands()) {
+      std::cout << " " << operand->name() << " ("
+                << ShapeUtil::HumanStringWithLayout(operand->shape()) << ")"
+                << std::endl;
+    }
+    std::cout << " Users:" << std::endl;
+    for (HloInstruction* user : instr->users()) {
+      std::cout << " " << user->name() << std::endl;
+    }
+    if (instr->parent()->root_instruction() == instr) {
+      std::cout << " Root instruction of " << instr->parent()->name()
+                << std::endl;
+    }
+  }
+}
+
+// Extracts the sub-graph rooted at the named instruction into a new module
+// and prints it.  `tokens` is {"extract", instruction} or
+// {"extract", instruction, height}.
+void DoExtractCommand(const HloModule& module,
+                      absl::Span<const string> tokens) {
+  // The instruction name is mandatory: reject a bare "extract" here, because
+  // the code below reads tokens[1] unconditionally.
+  if (tokens.size() < 2 || tokens.size() > 3) {
+    std::cerr << R"(Illegal input. Enter e.g. "extract %fusion.1 2")"
+              << std::endl;
+    return;
+  }
+
+  // Find the node with the given name.
+  string node_name = tokens[1];
+  HloInstruction* instr = FindInstruction(module, node_name);
+  if (!instr) {
+    std::cerr << "Couldn't find HloInstruction named " << node_name << "."
+              << std::endl;
+    return;
+  }
+
+  // -1 is the value handed to ExtractModule when no height was given.
+  int64 height = -1;
+  if (tokens.size() == 3) {
+    if (!absl::SimpleAtoi(tokens[2], &height)) {
+      std::cerr << "Can't parse '" << tokens[2] << "' as an integer."
+                << std::endl;
+      return;
+    }
+  }
+
+  auto extracted_module = ExtractModule(instr, height);
+  std::cout << extracted_module->ToString(
+                   HloPrintOptions::ShortParsable().set_print_backend_config(
+                       show_backend_config))
+            << std::endl;
+}
+
+// Checks if there is a use-def path from `from` to `to`.
+bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) {
+  std::unordered_set<const HloInstruction*> visited;
+  std::vector<const HloInstruction*> to_visit = {from};
+  while (!to_visit.empty()) {
+    const HloInstruction* n = to_visit.back();
+    to_visit.pop_back();
+    if (n == to) {
+      return true;
+    }
+    // A node can be pushed several times before it is first expanded; expand
+    // it only once so shared users are not re-enqueued again and again.
+    if (!visited.insert(n).second) {
+      continue;
+    }
+    for (auto* user : n->users()) {
+      if (!visited.count(user)) {
+        to_visit.push_back(user);
+      }
+    }
+  }
+  return false;
+}
+
+// Prints `handle` (the result of a graph dump) and, when it looks like a URL,
+// also launches the configured browser on it.
+void DisplayGraphHandle(const Options &opts, const string& handle) {
+  std::cout << handle << std::endl;
+
+  // If it is a url, try to open it up in the user's browser too.
+  if (absl::StartsWithIgnoreCase(handle, "http://") ||
+      absl::StartsWithIgnoreCase(handle, "https://") ||
+      absl::StartsWithIgnoreCase(handle, "file://")) {
+    const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser"
+                                                   : opts.browser.c_str();
+    tensorflow::SubProcess p;
+    p.SetProgram(browser_bin, {browser_bin, handle});
+    // The URL was already printed above, so a failed launch is only reported.
+    if (!p.Start()) {
+      std::cerr << "Unable to start browser '" << browser_bin << "'."
+                << std::endl;
+    }
+  } else if (handle.empty()) {
+    std::cerr << "Unable to render graph, perhaps due to graphviz server "
+                 "timeout. Run with --logtostderr to see."
+              << std::endl;
+  } else {
+    std::cerr << "\nExpected a URL, but got strange graph result (dumped "
+                 "above). If this isn't what you expected, maybe file a bug?"
+              << std::endl;
+  }
+}
+
+// Renders the paths between two instructions.  `tokens` is
+// {"allpaths", instruction, instruction} with an optional fourth token giving
+// the maximum number of nodes to show.
+void DoAllPathsCommand(const Options& opts, const HloModule& module,
+                       const std::vector<string>& tokens) {
+  // Both instruction names are mandatory; tokens[1] and tokens[2] are read
+  // below.
+  if (tokens.size() < 3 || tokens.size() > 4) {
+    std::cerr << R"(Illegal input. Enter e.g. "allpaths %add.4 %subtract.2" or
+"allpaths add.4 subtract.2 42".)"
+              << std::endl;
+    return;
+  }
+
+  int64 max_nodes = kDefaultMaxNumNodesInAllPaths;
+  if (tokens.size() == 4 && !absl::SimpleAtoi(tokens[3], &max_nodes)) {
+    std::cerr << "Can't parse '" << tokens[3] << "' as an integer."
+              << std::endl;
+    return;
+  }
+
+  const HloInstruction* n1 = FindInstruction(module, tokens[1]);
+  if (!n1) {
+    std::cerr << "Couldn't find HloInstruction named " << tokens[1]
+              << std::endl;
+    return;
+  }
+  const HloInstruction* n2 = FindInstruction(module, tokens[2]);
+  if (!n2) {
+    std::cerr << "Couldn't find HloInstruction named " << tokens[2]
+              << std::endl;
+    return;
+  }
+
+  // Is there a path from n1 to n2, or vice versa?
+  const HloInstruction* from;
+  const HloInstruction* to;
+  if (ExistsPathFromTo(n1, n2)) {
+    from = n1;
+    to = n2;
+  } else if (ExistsPathFromTo(n2, n1)) {
+    from = n2;
+    to = n1;
+  } else {
+    std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2]
+              << std::endl;
+    return;
+  }
+  DisplayGraphHandle(opts, hlo_graph_dumper::DumpAllPathsFromTo(
+      *from, *to, max_nodes, /*show_backend_config=*/show_backend_config));
+}
+
+// Plot a given instruction neighborhood or computation with graphviz.
+// `tokens` is {name [<width>] [/ <boundary_instruction>+]}; the caller
+// guarantees that tokens is non-empty.
+void DoPlotCommand(const Options& opts, const HloModule& module,
+                   const std::vector<string>& tokens) {
+  string node_name = tokens[0];
+
+  // Find the node with the given name.
+  const HloInstruction* instr = FindInstruction(module, node_name);
+  const HloComputation* comp = FindComputation(module, node_name);
+  if (!instr && !comp) {
+    std::cerr << "Couldn't find HloInstruction or HloComputation named "
+              << node_name << "." << std::endl;
+    return;
+  }
+
+  uint64 graph_width = kDefaultWidth;
+  absl::flat_hash_set<const HloInstruction*> boundary;
+  if (tokens.size() >= 2) {
+    if (comp) {
+      std::cerr << "Can only use graph-size parameter with instructions, but "
+                << node_name << " is a computation." << std::endl;
+      return;
+    }
+
+    int bound_index = 1;
+    // Get the <width> if present.
+    if (absl::SimpleAtoi(tokens[bound_index], &graph_width)) {
+      bound_index++;
+    } else {
+      // <width> not found, need to reset graph_width.
+      graph_width = kDefaultWidth;
+    }
+    // Get the '/'.
+    if (bound_index < tokens.size()) {
+      // This token must be a '/'.
+      if (tokens[bound_index] != "/") {
+        std::cerr << "Expect a /, but get a '" << tokens[bound_index] << "'."
+                  << std::endl;
+        return;
+      }
+      bound_index++;
+    }
+    // Get the boundary nodes.
+    while (bound_index < tokens.size()) {
+      string bnode_name = tokens[bound_index];
+      const HloInstruction* binstr = FindInstruction(module, bnode_name);
+      if (!binstr) {
+        std::cerr << "Couldn't find HloInstruction named " << bnode_name << "."
+                  << std::endl;
+        return;
+      }
+      boundary.insert(binstr);
+      bound_index++;
+    }
+  }
+
+  // Generate the graph and print the resulting string, which should be a
+  // graphviz url.
+  if (comp) {
+    DisplayGraphHandle(opts, hlo_graph_dumper::DumpGraph(
+        *comp, "", comp->parent()->config().debug_options(), nullptr,
+        /*show_backend_config=*/show_backend_config));
+  } else {
+    DisplayGraphHandle(opts, hlo_graph_dumper::DumpNeighborhoodAround(
+        *instr, graph_width,
+        /*show_backend_config=*/show_backend_config,
+        /*boundary=*/boundary));
+  }
+}
+
+// Run the main event loop, reading user commands and processing them.
+void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
+  // This is an interactive tool, but some may use `extract` in non-tty
+  // environment anyway. Give them a clean hlo dump.
+  if (isatty(fileno(stdin))) {
+    std::cout << "\n\nLoaded module " << module.name() << "." << std::endl;
+    DoHelpCommand();
+  }
+  for (string line; ReadLine("\ncommand: ", &line);) {
+    // Tokenize before looking at the first word: a line holding only spaces
+    // is not empty() but still yields no tokens, and tokens[0] is read below.
+    std::vector<string> tokens = absl::StrSplit(line, ' ', absl::SkipEmpty());
+    if (tokens.empty()) {
+      std::cout << R"(Enter e.g. "fusion.1 3" or "add.8".)" << std::endl
+                << R"(Enter "help" for help; ^D, "quit", or "exit" to exit.)"
+                << std::endl;
+      continue;
+    }
+    if (tokens[0] == "quit" || tokens[0] == "exit") {
+      break;
+    } else if (tokens[0] == "help") {
+      DoHelpCommand();
+    } else if (tokens[0] == "backend_config") {
+      DoBackendConfigCommand(tokens);
+    } else if (tokens[0] == "list") {
+      if (tokens.size() > 1 && tokens[1] == "computations") {
+        DoListComputationsCommand(module, tokens);
+      } else {
+        DoListCommand(module, tokens);
+      }
+    } else if (tokens[0] == "info") {
+      DoInfoCommand(module, tokens);
+    } else if (tokens[0] == "extract") {
+      DoExtractCommand(module, tokens);
+    } else if (tokens[0] == "allpaths") {
+      DoAllPathsCommand(opts, module, tokens);
+    } else {
+      // Anything else is treated as the name of a node or computation to plot.
+      DoPlotCommand(opts, module, tokens);
+    }
+  }
+}
+
+// Verifies that exactly one of the module-input flags was given; aborts
+// otherwise.
+void CheckFlags(const Options &opts) {
+  std::vector<string> nonempty_proto_flags;
+  if (!opts.hlo_proto.empty()) {
+    nonempty_proto_flags.push_back("--hlo_proto");
+  }
+  if (!opts.hlo_snapshot.empty()) {
+    nonempty_proto_flags.push_back("--hlo_snapshot");
+  }
+  if (!opts.hlo_text.empty()) {
+    nonempty_proto_flags.push_back("--hlo_text");
+  }
+  switch (nonempty_proto_flags.size()) {
+    case 1:
+      // We're good to go.
+      break;
+    case 0:
+      // No input flag was given, so the collected list is empty here; name
+      // the accepted flags explicitly instead of joining an empty list.
+      LOG(FATAL) << "Need one of the following options: "
+                 << "--hlo_proto, --hlo_snapshot, --hlo_text";
+    default:
+      LOG(FATAL) << "Can only specify one of "
+                 << absl::StrJoin(nonempty_proto_flags, ", ");
+  }
+}
+
+// Loads the module named by the flags, optionally compiles it for
+// --platform, and then enters the interactive command loop.
+void RealMain(const Options& opts) {
+  if (!isatty(fileno(stdin))) {
+    LOG(ERROR) << "\n\n*****************************************\n"
+               << "This is an interactive tool, but stdin is not a tty.\n"
+               << "*****************************************\n\n";
+  }
+
+  CheckFlags(opts);
+
+  // CheckFlags guarantees that exactly one of the branches below is taken.
+  std::unique_ptr<HloModule> module;
+  if (!opts.hlo_snapshot.empty()) {
+    HloSnapshot snapshot;
+    TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
+                                            opts.hlo_snapshot, &snapshot))
+        << "Can't open, read, or parse HloSnapshot proto at "
+        << opts.hlo_snapshot;
+    auto config =
+        HloModule::CreateModuleConfigFromProto(snapshot.hlo().hlo_module(),
+                                               xla::GetDebugOptionsFromFlags())
+            .ValueOrDie();
+    module = HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config)
+                 .ValueOrDie();
+  } else if (!opts.hlo_proto.empty()) {
+    module = HloRunner::ReadModuleFromBinaryProtoFile(
+                 opts.hlo_proto, xla::GetDebugOptionsFromFlags())
+                 .ValueOrDie();
+  } else if (!opts.hlo_text.empty()) {
+    module = HloRunner::ReadModuleFromHloTextFile(
+                 opts.hlo_text, xla::GetDebugOptionsFromFlags())
+                 .ValueOrDie();
+  }
+
+  // If a platform was specified, compile the module for that platform.
+  if (!opts.platform.empty()) {
+    se::Platform* platform =
+        PlatformUtil::GetPlatform(opts.platform).ValueOrDie();
+    LOG(INFO) << "Compiling module for " << platform->Name();
+
+    se::StreamExecutor* executor =
+        platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie();
+    auto compiler = Compiler::GetForPlatform(platform).ValueOrDie();
+    module = compiler
+                 ->RunHloPasses(std::move(module), executor,
+                                /*device_allocator=*/nullptr)
+                 .ValueOrDie();
+    auto executable = compiler
+                          ->RunBackend(std::move(module), executor,
+                                       /*device_allocator=*/nullptr)
+                          .ValueOrDie();
+    // `module` was moved into the backend; explore the executable's module.
+    InteractiveDumpGraphs(opts, executable->module());
+  } else {
+    InteractiveDumpGraphs(opts, *module);
+  }
+}
+
+}  // namespace
+}  // namespace tools
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  xla::tools::Options opts;
+  opts.browser = "/usr/bin/sensible-browser";
+  bool need_help = false;
+  const std::vector<tensorflow::Flag> flag_list = {
+      tensorflow::Flag("hlo_snapshot", &opts.hlo_snapshot,
+                       "HloSnapshot proto to interactively dump to graphviz"),
+      tensorflow::Flag("hlo_proto", &opts.hlo_proto,
+                       "XLA hlo proto to interactively dump to graphviz"),
+      tensorflow::Flag("hlo_text", &opts.hlo_text,
+                       "XLA hlo text to interactively dump to graphviz"),
+      tensorflow::Flag("platform", &opts.platform,
+                       "Platform to compile for: CPU, CUDA, etc"),
+      tensorflow::Flag("browser", &opts.browser,
+                       "Path to web browser used to display produced graphs."),
+      tensorflow::Flag("help", &need_help,
+                       "Prints this help message"),
+  };
+  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  // Any argument left over after flag parsing is treated as a usage error.
+  if (argc != 1 || !parse_ok || need_help) {
+    LOG(QFATAL) << usage;
+  }
+  xla::tools::RealMain(opts);
+  return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh b/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh
new file mode 100755
index 
0000000000000000000000000000000000000000..b3e43aa7da062547fb5f187b885e997fc44bbb65 --- /dev/null +++ b/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh @@ -0,0 +1,19 @@ +#! /bin/bash +# /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ==============================================================================*/ + +# This is a placeholder for a compile-only test for the interactive_graphviz tool. + +exit 0 diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ff2c3399928c0e6339304323c4f93e212933a340..d66561315b4ad7a5e3f1f7b1bc1e557b71da6705 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -73,14 +74,24 @@ namespace { // fields.
struct Options { string fake_infeed_shape; - bool generate_fake_infeed = false; + string fake_outfeed_shape; + + // generate_fake_infeed == true is a safe default: If the model has 0 or 1 + // infeeds, then it will work like normal. If the model has more than one + // infeed, it will be an error, but that wouldn't have worked anyway if you + // hadn't passed generate_fake_infeed. + // + // Same for generate_fake_outfeed. + bool generate_fake_infeed = true; + bool generate_fake_outfeed = true; + bool use_fake_data = false; bool print_result = true; int num_runs = 1; }; -std::unique_ptr CompileExecutable(const HloSnapshot& module, - LocalClient* client) { +StatusOr> CompileExecutable( + const HloSnapshot& module, LocalClient* client) { XlaComputation computation(module.hlo().hlo_module()); std::vector argument_layouts; argument_layouts.reserve( @@ -91,9 +102,86 @@ std::unique_ptr CompileExecutable(const HloSnapshot& module, argument_layouts.push_back(Shape(param)); argument_layout_ptrs.push_back(&argument_layouts.back()); } - return client - ->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions()) - .ValueOrDie(); + ExecutableBuildOptions exec_build_options; + *exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags(); + return client->Compile(computation, argument_layout_ptrs, exec_build_options); +} + +absl::optional GetXfeedShape(bool is_infeed, + const HloModuleProto& module, + const Options& opts) { + std::vector xfeed_instrs; + for (const auto& comp : module.computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.opcode() == HloOpcodeString(is_infeed + ? 
HloOpcode::kInfeed + : HloOpcode::kOutfeed)) { + xfeed_instrs.push_back(instruction); + } + } + } + + auto log_xfeed_instrs = [&] { + for (const auto& infeed : xfeed_instrs) { + LOG(ERROR) << " " << ShapeUtil::HumanString(Shape(infeed.shape())) << " " + << infeed.name(); + } + }; + + auto find_instruction_from_id_or_die = [&](int64 id) { + for (const auto& comp : module.computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.id() == id) { + return instruction; + } + } + } + LOG(FATAL) << "No instruction with id " << id; + }; + + absl::optional xfeed_shape; + string xfeed_name = is_infeed ? "infeed" : "outfeed"; + string fake_xfeed_shape = + is_infeed ? opts.fake_infeed_shape : opts.fake_outfeed_shape; + bool generate_fake_xfeed = + is_infeed ? opts.generate_fake_infeed : opts.generate_fake_outfeed; + if (!fake_xfeed_shape.empty()) { + xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie(); + } else if (generate_fake_xfeed) { + CHECK_LT(xfeed_instrs.size(), 2) + << "--generate_fake_" << xfeed_name + << " only works if the model has 0 or 1 " << xfeed_name << " ops."; + if (xfeed_instrs.empty()) { + LOG(INFO) << "Not generating fake " << xfeed_name + << " shape; model has no " << xfeed_name << "s."; + } else if (xfeed_instrs.size() == 1) { + // kInfeed instructions should have a shape (buffer, token). kOutfeed + // instructions should have operand 0 of shape `buffer`. We want to xfeed + // just `buffer`. + xfeed_shape = is_infeed + ? 
Shape(xfeed_instrs.front().shape()).tuple_shapes(0) + : Shape(find_instruction_from_id_or_die( + xfeed_instrs.front().operand_ids(0)) + .shape()); + LOG(INFO) << "Generating fake " << xfeed_name << " with inferred shape: " + << ShapeUtil::HumanString(*xfeed_shape); + } else { + LOG(ERROR) << "--generate_fake_" << xfeed_name + << " only works if the model has 0 or 1 " << xfeed_name + << " ops, but this model has " << xfeed_instrs.size() + << " of them:"; + log_xfeed_instrs(); + LOG(FATAL) << "Can't run model with --generate_fake_infeed."; + } + } else if (!xfeed_instrs.empty()) { + LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name + << " instruction(s), but neither --generate_fake_" << xfeed_name + << " nor --fake_" << xfeed_name + << "_shape was specified. Execution will likely hang."; + log_xfeed_instrs(); + } + + return xfeed_shape; } // Invokes the given computation passing arbitrary data for every (unbound) @@ -118,7 +206,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, std::vector> global_data_arguments; std::vector argument_ptrs; if (opts.use_fake_data) { - global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + // Run fake computations with debug options ignoring XLA_FLAGS. Users very + // likely want XLA_FLAGS only to apply to the "real" computation being run, + // not to the fake computations we use for generating arguments. 
+ auto debug_opts = DefaultDebugOptionsIgnoringFlags(); + global_data_arguments = + MakeFakeArgumentsOrDie(computation, client, &debug_opts); for (const auto& data : global_data_arguments) { argument_ptrs.push_back( client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) @@ -137,55 +230,37 @@ StatusOr ReplayComputation(const HloSnapshot& module, } } - bool provide_infeed = false; - Shape infeed_shape; - if (!opts.fake_infeed_shape.empty()) { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); - TF_CHECK_OK(shape_status.status()); - infeed_shape = std::move(shape_status).ValueOrDie(); - provide_infeed = true; - } else if (opts.generate_fake_infeed) { - for (const auto& comp : computation.proto().computations()) { - for (const auto& instruction : comp.instructions()) { - if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) { - CHECK(!provide_infeed) - << "--generate_fake_infeed only works if the model has 0 or 1 " - "infeed ops, but this one has >= 2."; - provide_infeed = true; - infeed_shape = Shape(instruction.shape()); - LOG(INFO) << "Generating fake infeed shape for inferred shape: " - << ShapeUtil::HumanString(infeed_shape); - } - } - } - } - // We only instantiate the thread pool if the user has requested that a - // concurrent infeed occur via the fake_infeed_shape, or when - // --generate_fake_infeed is passed and there exists an infeed operation in - // the HloSnapshot. 
-  absl::optional<tensorflow::thread::ThreadPool> pool;
-  Literal data;
-  if (provide_infeed) {
-    data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie();
+  if (absl::optional<Shape> infeed_shape = GetXfeedShape(
+          /*is_infeed=*/true, computation.proto(), opts)) {
+    auto infeed_data = std::make_shared<Literal>(
+        std::move(MakeFakeLiteral(*infeed_shape)).ValueOrDie());
+    xla::gpu::GetOrCreateInfeedManager()
+        ->RegisterBeforeGetNextDestinationCallback([infeed_data, client] {
+          TF_CHECK_OK(client->TransferToInfeed(*infeed_data));
+        });
   }
-  auto transfer_infeed = [&data, client]() {
-    TF_CHECK_OK(client->TransferToInfeed(data));
-  };
-  if (provide_infeed) {
-    pool.emplace(tensorflow::Env::Default(), "infeed",
-                 /*num_threads=*/1);
-    pool->Schedule([transfer_infeed]() {
-      // There may be several infeed buffers needed, however we don't know how
-      // many.  If we proactively transfer too many infeed buffers, we may run
-      // out of memory.  If we transfer too few infeed buffers, the program will
-      // hang.  Therefore, we register a callback that is called when the infeed
-      // becomes empty, and in this callback we will transfer another fake
-      // infeed.
-      auto infeed_manager = xla::gpu::GetOrCreateInfeedManager();
-      infeed_manager->RegisterOnEmptyCallback(transfer_infeed);
-      transfer_infeed();
-    });
+
+  absl::optional<tensorflow::thread::ThreadPool> outfeed_thread_pool;
+  if (absl::optional<Shape> outfeed_shape = GetXfeedShape(
+          /*is_infeed=*/false, computation.proto(), opts)) {
+    // For each outfeed that runs, enqueue a task that will consume it.  We
+    // need a thread pool because the act of running an outfeed blocks on there
+    // being a destination available, and the act of making a destination
+    // available blocks on there being outfeed data available.
+    outfeed_thread_pool.emplace(tensorflow::Env::Default(), "outfeed",
+                                /*num_threads=*/1);
+    auto consume_outfeed = [client, outfeed_shape] {
+      TF_CHECK_OK(
+          client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0)
+              .status());
+      VLOG(1) << "Received outfeed data of shape "
+              << ShapeUtil::HumanStringWithLayout(*outfeed_shape);
+    };
+    xla::gpu::GetOrCreateOutfeedManager()
+        ->RegisterBeforeGetNextDestinationCallback(
+            [consume_outfeed, &outfeed_thread_pool] {
+              outfeed_thread_pool->Schedule(consume_outfeed);
+            });
   }
 
   // Do not attempt to run the executable if num_runs is less than 1.
@@ -254,7 +329,10 @@ StatusOr<HloSnapshot> ParseInputFile(const string& filename,
   fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str());
   string contents;
   TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents));
-  StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(contents);
+  HloModuleConfig config;
+  config.set_debug_options(GetDebugOptionsFromFlags());
+  StatusOr<std::unique_ptr<HloModule>> module =
+      ParseHloString(contents, config);
   if (module.ok()) {
     *snapshot.mutable_hlo()->mutable_hlo_module() =
         module.ValueOrDie()->ToProto();
@@ -282,7 +360,7 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
 
   // Compile all the modules in parallel.
   LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel.";
-  std::vector<std::unique_ptr<LocalExecutable>> executables;
+  std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables;
   {
     // ThreadPool CHECK-fails if we give it 0 threads.
tensorflow::thread::ThreadPool thread_pool( @@ -299,9 +377,16 @@ int RealMain(absl::Span args, const Options& opts) { LOG(INFO) << "Done compiling; now running the modules."; for (int64 i = 0; i < executables.size(); ++i) { - LocalExecutable* executable = executables[i].get(); + if (!executables[i].ok()) { + LOG(ERROR) << "Compilation failed: " << executables[i].status(); + exit_status = EXIT_FAILURE; + continue; + } + LocalExecutable* executable = executables[i].ValueOrDie().get(); + LOG(ERROR) << "Running iteration " << i; StatusOr result_status = ReplayComputation(snapshots[i], executable, client, opts); + LOG(ERROR) << "iteration complete."; if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", args[i], result_status.status().ToString().c_str()); @@ -346,9 +431,14 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), + tensorflow::Flag("fake_outfeed_shape", &opts.fake_outfeed_shape, + "Shape of fake data to outfeed from computation"), tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, - "Whether a fake infeed shape should be generated " - "derived from the computation"), + "Whether a fake infeed shape should be derived " + "from the computation"), + tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed, + "Whether a fake outfeed shape should be derived " + "from the computation"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index cdf306dfd1027cf6022c5d8ae844b4308f580e8d..b80d0db8d812380d8144713109d1c05168713c77 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index b645acb700b0f168112a40c9c72b4669435f717d..daf678f69017b9eb86cbc464a1f33b434021901d 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -41,6 +41,7 @@ using ::tensorflow::uint32; using ::tensorflow::uint64; using complex64 = std::complex; +using complex128 = std::complex; using ::Eigen::half; diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 68cab7387cf1576072f96878b50f07def6862d8b..bb8bbf57c4252b16836553334901a3c896a17f39 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -80,13 +81,9 @@ bool IsPermutation(absl::Span permutation, int64 rank) { if (rank != permutation.size()) { return false; } - std::vector output(permutation.size(), -1); - for (auto index : permutation) { - CHECK_GE(index, 0); - CHECK_LT(index, rank); - output[index] = 0; - } - return std::find(output.begin(), output.end(), -1) == output.end(); + absl::InlinedVector trivial_permutation(rank); + absl::c_iota(trivial_permutation, 0); + return absl::c_is_permutation(permutation, trivial_permutation); } std::vector InversePermutation( diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 6722641e9d2c177440361e6f0d1f6c0804eb7cda..f2fd17dc99455a921bf875aad2a3661b4d456823 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -324,8 +324,7 @@ bool IsIdentityPermutation(absl::Span permutation); template int64 PositionInContainer(const Container& container, int64 value) { - return std::distance(container.begin(), - std::find(container.begin(), container.end(), value)); + return std::distance(container.begin(), absl::c_find(container, value)); } // Formats the container as a comma-separated string. StrAppend must support diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 51c73b3d17e4c32d9a8a14d3055ab56f02922af3..e001cc35f9fcea2783b3952e825838af6bbece72 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -137,25 +138,23 @@ bool HasPadding(const Window& window) { } bool HasSymmetricPadding(const Window& window) { - return std::all_of(window.dimensions().begin(), window.dimensions().end(), - [](const WindowDimension& dim) { - return dim.padding_low() == dim.padding_high(); - }); + return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) { + return dim.padding_low() == dim.padding_high(); + }); } bool HasSymmetricPadding(const PaddingConfig& padding_config) { - return std::all_of(padding_config.dimensions().begin(), - padding_config.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.edge_padding_low() == dim.edge_padding_high(); - }); + return absl::c_all_of(padding_config.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.edge_padding_low() == + dim.edge_padding_high(); + }); } bool HasNegativePadding(const Window& window) { - return std::any_of(window.dimensions().begin(), window.dimensions().end(), - [](const WindowDimension& dim) { - return dim.padding_low() < 0 || dim.padding_high() < 0; - }); + return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) { + return dim.padding_low() < 0 || dim.padding_high() < 0; + }); } bool HasBaseDilation(const Window& window) { @@ -190,10 +189,9 @@ bool AllOrNoneReversed(const Window& window) { return true; } bool reversed = window.dimensions()[0].window_reversal(); - return std::all_of(window.dimensions().begin(), window.dimensions().end(), - [&](const WindowDimension& dim) { - return dim.window_reversal() == reversed; - }); + return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) { + return dim.window_reversal() == reversed; + }); } bool HasDilation(const Window& window) { diff --git a/tensorflow/compiler/xla/xla.bzl 
b/tensorflow/compiler/xla/xla.bzl index 1439f1bcc5cec39203a7cb4b1f8604e7349382c6..cda2d7c7c6b2403868f6d01a485753fa29a8d95f 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -1,30 +1,47 @@ """Wrapper around cc_proto_library used inside the XLA codebase.""" -load("//tensorflow/core:platform/default/build_config.bzl", - "cc_proto_library") -load("//tensorflow/core:platform/default/build_config_root.bzl", - "if_static") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "cc_proto_library", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", +) +load("//tensorflow:tensorflow.bzl", "if_cuda_is_configured") # xla_proto_library() is a convenience wrapper around cc_proto_library. -def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0, **kwargs): - if kwargs.get('use_grpc_plugin'): - kwargs['use_grpc_namespace'] = True - cc_proto_library(name=name, - srcs=srcs, - deps=deps, - cc_libs = if_static( - ["@protobuf_archive//:protobuf"], - otherwise=["@protobuf_archive//:protobuf_headers"], - ), - protoc="@protobuf_archive//:protoc", - testonly=testonly, - visibility=visibility, - **kwargs) +def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = 0, **kwargs): + if kwargs.get("use_grpc_plugin"): + kwargs["use_grpc_namespace"] = True + cc_proto_library( + name = name, + srcs = srcs, + # Append well-known proto dep. As far as I know this is the only way + # for xla_proto_library to access google.protobuf.{Any,Duration,...}. + deps = deps + ["@protobuf_archive//:cc_wkt_protos"], + cc_libs = if_static( + ["@protobuf_archive//:protobuf"], + otherwise = ["@protobuf_archive//:protobuf_headers"], + ), + protoc = "@protobuf_archive//:protoc", + testonly = testonly, + visibility = visibility, + **kwargs + ) -def xla_py_grpc_library(**kwargs): - # Note: we don't currently define any special targets for Python GRPC in OSS. 
- _ignore = kwargs - pass +def xla_py_proto_library(**kwargs): + # Note: we don't currently define a proto library target for Python in OSS. + _ignore = kwargs + pass +def xla_py_grpc_library(**kwargs): + # Note: we don't currently define any special targets for Python GRPC in OSS. + _ignore = kwargs + pass ORC_JIT_MEMORY_MAPPER_TARGETS = [] + +# We link the GPU plugin into the XLA Python extension if CUDA is enabled. +def xla_python_default_plugins(): + return if_cuda_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"]) diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index a37eac7fe441d91aa71e1b6fd7b84099fee2215b..925fcbf88c1e8dd81ab1339d292e05eae52e0d13 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -15,11 +15,11 @@ limitations under the License. syntax = "proto3"; -import "tensorflow/compiler/xla/xla_data.proto"; -import "tensorflow/compiler/xla/service/hlo.proto"; - package xla; +import "tensorflow/compiler/xla/service/hlo.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + // Options for the HLO insert-reduce-precision-operations pass. message HloReducePrecisionOptions { // Where and when the reduce-precision operations will be added. @@ -72,8 +72,7 @@ message DebugOptions { // Path to dump HLO graphs to. string xla_hlo_graph_path = 4; - // Dump HLO graphs as TensorFlow GraphDefs. - bool xla_hlo_dump_as_graphdef = 5; + reserved 5; // Was xla_hlo_dump_as_graphdef // HLO modules matching this regex will be dumped to LOG(INFO). Set to ".*" to // dump *all* HLO modules. @@ -100,6 +99,14 @@ message DebugOptions { // names as specified by the HloPassInterface::name() method. repeated string xla_disable_hlo_passes = 30; + // Disables all HLO passes. Note that some passes are necessary for + // correctness and the invariants that must be satisfied by "fully optimized" + // HLO are different for different devices and may change over time. 
The only + // "guarantee", such as it is, is that if you compile XLA and dump the + // optimized HLO for some graph, you should be able to run it again on the + // same device with the same build of XLA. + bool xla_disable_all_hlo_passes = 104; + // Numerical optimization level for the XLA compiler backend; the specific // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; @@ -163,9 +170,7 @@ message DebugOptions { // HLO graph. bool xla_hlo_graph_sharding_color = 92; - // Prefix the name scopes of the TF graph exports with "devX" device - // assignments, if available. - bool xla_hlo_tfgraph_device_scopes = 93; + reserved 93; // Was xla_hlo_tfgraph_device_scopes // If true, the GPU backend is free to use cudnn for HLO batch normalization // ops. @@ -216,6 +221,34 @@ message DebugOptions { // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). bool xla_gpu_disable_ptxas_optimizations = 103; + // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + bool xla_hlo_dump_as_html = 105; + + // Enable fast math with eigen in the HLO evaluator. + bool xla_hlo_evaluator_use_fast_path = 106; + + // Temporary option to allow support for both the R1 and the scalar index + // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. + bool xla_allow_scalar_index_dynamic_ops = 107; + + enum StepMarkerLocation { + // Generate step mark at program entry. This handles the case where each + // step is done by one or more program executions. Only the first program + // is tagged for generating the step mark. This is the default. + STEP_MARK_AT_ENTRY = 0; + // Generate step mark at each iteration of the top level while loop, + // which is assumed to be a training loop. + STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1; + // No step mark. + STEP_MARK_NONE = 2; + } + // Option to emit a target-specific marker to indicate the start of a training + // step. 
The location of the marker (if any) is determined by the option + // value. + StepMarkerLocation xla_step_marker_location = 108; + + // Next id: 109 + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; @@ -245,6 +278,10 @@ message ExecutionOptions { // computation on. The computation will be partitioned across these devices. // If not provided, the default device will be chosen. repeated DeviceHandle device_handles = 5; + + // Number of replicas of the computation to run. If zero, uses the default + // number of replicas for the XLA service. + int32 num_replicas = 6; } message GetDeviceHandlesRequest { @@ -282,8 +319,7 @@ message TransferToInfeedRequest { DeviceHandle device_handle = 3; } -message TransferToInfeedResponse { -} +message TransferToInfeedResponse {} message TransferFromOutfeedRequest { // This optional field directs the service to return the literal in this @@ -302,8 +338,7 @@ message ResetDeviceRequest { DeviceHandle device_handle = 1; } -message ResetDeviceResponse { -} +message ResetDeviceResponse {} message ComputationGraphStatsRequest { HloModuleProto computation = 1; @@ -326,8 +361,7 @@ message UnregisterRequest { repeated GlobalDataHandle data = 1; } -message UnregisterResponse { -} +message UnregisterResponse {} message CompileRequest { // The graph to be compiled. @@ -389,7 +423,7 @@ message WaitForExecutionResponse { message ComputeConstantGraphRequest { HloModuleProto computation = 1; - Layout output_layout = 2; + LayoutProto output_layout = 2; } message ComputeConstantResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 85ec83437a10d973687a7fb84285c2e2541a53c7..226299a7186ef0acb41f6d01fdeffeee06f13d4d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -56,6 +56,7 @@ enum PrimitiveType { // Complex values of fixed width. 
C64 = 15; // Paired F32 (real, imag), as in std::complex. + C128 = 18; // Paired F64 (real, imag), as in std::complex. // A tuple is a polymorphic sequence; e.g. a shape that holds different // sub-shapes. They are used for things like returning multiple values from a @@ -75,7 +76,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 18 + // Next = 19 } // Describes the padding configuration for Pad operation. The padding amount on @@ -100,6 +101,8 @@ message PaddingConfig { // A format specifies the method used by a layout to store an array in memory. enum Format { + // TODO(b/120869032): Rename this to FORMAT_NONE or something else which + // better corresponds to its meaning. INVALID_FORMAT = 0; // The default layout, with exactly one storage location per element. DENSE = 1; @@ -109,8 +112,9 @@ enum Format { } // Describes a tile used in tiling-based layout. Refer to -// g3doc/layout_with_tiling.md for details about tiling-based layout. -message Tile { +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details about tiling-based layout. +message TileProto { // Number of elements in each dimension of the tile. It's ordered from the // most major dimension of the tile to the most minor dimension of the tile. // The dimensions correspond to a suffix of the dimensions of the shape being @@ -128,7 +132,7 @@ message Tile { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Layout { +message LayoutProto { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. Format format = 4; @@ -153,7 +157,7 @@ message Layout { // // TODO(b/119839262): implement tiling in each backend or add Unimplemented // error. - repeated Tile tiles = 6; + repeated TileProto tiles = 6; // Bit size of each element. 
If the size is bigger than what the element // type requires, the value is stored in the least significant @@ -185,18 +189,27 @@ message ShapeProto { // The element type for this shape. PrimitiveType element_type = 2; - // The size (number of elements) for each dimension. - // In XLA, dimensions are numbered from 0 to N-1 for an - // N-dimensional array. The first element of 'dimensions' is the size of - // dimension 0, the second element is the size of dimension 1, and so forth. - // Empty list indicates a scalar. + // The size (number of elements) for each dimension, or an upper bound on the + // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + // to N-1 for an N-dimensional array. The first element of 'dimensions' is the + // size of dimension 0, the second element is the size of dimension 1, and so + // forth. Empty list indicates a scalar. + // + // If the respective element in 'is_dynamic_dimension' is true then the value + // in this field represents an upper bound on the size of the dimension. repeated int64 dimensions = 3; // For tuples only, the shapes of constitutent shapes in the tuple sequence. repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. - Layout layout = 5; + LayoutProto layout = 5; + + // For arrays, this indicates whether or not each dimension is + // dynamically-sized. The number of elements in this repeated field should be + // zero (indicating that no dimensions are dynamic) or equal to the number of + // elements in the 'dimensions' field. + repeated bool is_dynamic_dimension = 6; // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for @@ -355,6 +368,7 @@ message LiteralProto { repeated float f32s = 8; repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated double c128s = 18; // Stored as interleaved real, imag doubles. 
repeated LiteralProto tuple_literals = 10; // The F16s, BF16s, U16s and S16s are encoded in little endian byte order bytes f16s = 11; @@ -362,7 +376,7 @@ message LiteralProto { bytes u16s = 16; bytes s16s = 17; repeated int64 sparse_indices = 14; - // Next = 18 + // Next = 19 } message WindowDimension { @@ -531,6 +545,26 @@ enum RandomDistribution { // Next: 4 } +message TriangularSolveOptions { + // If true, solves ax = b. If false, solves xa = b. + bool left_side = 1; + + // If true, 'a' is lower triangular. If false, 'a' is upper triangular. + bool lower = 2; + + // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed. + bool unit_diagonal = 3; + + // Should we transpose or use the adjoint of 'a'? + enum Transpose { + TRANSPOSE_INVALID = 0; + NO_TRANSPOSE = 1; // Don't transpose 'a'. + TRANSPOSE = 2; // Transpose 'a'. + ADJOINT = 3; // Complex conjugate and transpose 'a'. + }; + Transpose transpose_a = 4; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -590,3 +624,15 @@ message PrecisionConfig { // Next: 2 } + +// Describes whether all data-parallelism replicas will receive the same +// parameter data at each buffer. +message ParameterReplication { + // A list of boolean values for the flattened leaf buffers. Each value + // indicates whether the corresponding leaf buffer is replicated. + // + // If this field is empty, it means no buffer is replicated. Otherwise, the + // number of elements in this field must match the number of leaf buffers in + // the HLO instruction's shape. 
+ repeated bool replicated_at_leaf_buffers = 1; +} diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 2dae746d034a1bf52e84de74dfb0c6e23aaed4d1..b2718c5c283358d98da175a8d3b21bb1f2b01c75 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -11,9 +11,15 @@ package( load( "//tensorflow:tensorflow.bzl", + "tf_custom_op_py_library", "tf_gen_op_libs", + "tf_gen_op_wrapper_py", ) load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( name = "xrt_proto", @@ -27,6 +33,12 @@ xla_proto_library( ], ) +tf_proto_library_py( + name = "xrt_proto", # bzl adds a _py suffix + srcs = ["xrt.proto"], + visibility = ["//visibility:public"], +) + cc_library( name = "xrt_utils", srcs = [ @@ -78,6 +90,25 @@ tf_gen_op_libs( ], ) +tf_gen_op_wrapper_py( + name = "xrt_ops_wrapper_py", + out = "xrt_ops.py", + deps = [ + ":xrt_compile_ops_op_lib", + ":xrt_execute_op_op_lib", + ":xrt_state_ops_op_lib", + ], +) + +tf_custom_op_py_library( + name = "xrt_ops", + kernels = ["//tensorflow/compiler/xrt/kernels:xrt_ops"], + visibility = ["//visibility:public"], + deps = [ + ":xrt_ops_wrapper_py", + ], +) + cc_library( name = "xrt_server", visibility = ["//visibility:public"], diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 67f475846e5f16060c1080759b0acb4216c4e72b..1e325191bba828e3d5e4599f87dcf4f4d0674945 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -11,20 +11,15 @@ cc_library( name = "xrt_state_ops", hdrs = ["xrt_state_ops.h"], deps = [ + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", 
"//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", @@ -55,14 +50,18 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xrt:xrt_compile_ops_op_lib", + "//tensorflow/compiler/xrt:xrt_execute_op_op_lib", "//tensorflow/compiler/xrt:xrt_proto", + "//tensorflow/compiler/xrt:xrt_state_ops_op_lib", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/stream_executor:stream_executor_headers_lib", + "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 2ccdf0f02d840600d5e0649c4805e3672d4a1286..b791519c09758a4f4124c95add5351a9433ecb8f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -68,9 +68,11 @@ class XRTCompileOp : public OpKernel { Status CompilationCacheKey(const xrt::XLAComputation& computation, string* key) { - string serialized; - TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized)); - uint64 fingerprint = Fingerprint64(serialized); + const size_t size = computation.ByteSizeLong(); + auto 
serialized = absl::make_unique(size); + TF_RET_CHECK( + SerializeToBufferDeterministic(computation, serialized.get(), size)); + uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); *key = absl::StrCat(fingerprint); return Status::OK(); } @@ -215,11 +217,6 @@ XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default; void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; - const Tensor& key_tensor = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(key_tensor.shape()), - errors::Internal("computation key should be a string scalar")); - int64 uid = key_tensor.scalar()(); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm)); @@ -230,9 +227,13 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { kXRTCompilationCacheResourceName, &cache)); core::ScopedUnref cache_unref(cache); - OP_REQUIRES_OK(ctx, cache->Release(uid)); - - VLOG(2) << "Released computation handle " << uid; + const Tensor& keys_tensor = ctx->input(0); + auto flat_keys = keys_tensor.flat(); + for (int64 i = 0; i < flat_keys.size(); ++i) { + int64 key = flat_keys(i); + OP_REQUIRES_OK(ctx, cache->Release(key)); + VLOG(2) << "Released computation handle " << key; + } } } // namespace diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 751329eefc33f3372335c805233dafabbf42bf36..42ef88168af4b6f391ffc2e69ab4c4000d1cbee1 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -19,10 +19,10 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/compiler/xrt/xrt_device.h" @@ -228,8 +228,27 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), &output_tuple)); + + // The ScopedShapedBuffer returned by the executable Run() API, in case of + // input/output buffer aliasing, might have holes in it, which need to be + // filled using the proper input tuples buffers which are the source of + // aliasing. + const xla::HloInputOutputAliasConfig& input_output_alias = + executable->executable()->module().input_output_alias_config(); + auto alias_function = + [&](const xla::ShapeIndex& output_index, + const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { + TF_RET_CHECK(alias.parameter_number < input_tuples.size()); + return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias + ? 
output_tuple->AliasBufferFrom( + *input_tuples[alias.parameter_number], + alias.parameter_index, output_index) + : Status::OK(); + }; + TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function)); + if (config_proto.return_exploded_tuple() && - xla::ShapeUtil::IsTuple(output_tuple->on_device_shape())) { + output_tuple->on_device_shape().IsTuple()) { int64 tuple_element_count = xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); Tensor* output_tensor; diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index 3258286c10665225aab917107ffa614459c53f3d..343f43b7159b55bad184eed2cada55c76085ffa0 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -37,6 +37,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate") .HostMemory("handle"), XRTAllocateOp); +REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") + .Device(DEVICE_XLA_GPU) + .HostMemory("inputs") + .HostMemory("handle"), + XRTAllocateFromTensorOp); +REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") + .Device(DEVICE_XLA_CPU) + .HostMemory("inputs") + .HostMemory("handle"), + XRTAllocateFromTensorOp); + REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") .Device(DEVICE_XLA_GPU) .HostMemory("base_handle") @@ -111,6 +122,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") + .Device(DEVICE_XLA_GPU) + .HostMemory("handles") + .HostMemory("tensors"), + XRTReadToTensorOp); +REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") + .Device(DEVICE_XLA_CPU) + .HostMemory("handles") + .HostMemory("tensors"), + XRTReadToTensorOp); + REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .Device(DEVICE_XLA_GPU) .HostMemory("handle"), @@ -120,4 +142,9 @@ REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .HostMemory("handle"), XRTReleaseAllocationOp); 
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), + XRTReleaseAllAllocationsOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), + XRTReleaseAllAllocationsOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 26a58fa42d8b730b365b11d2e5608e9945497763..6af73ecc85351a9b38ba526db076e9176d1cb2f1 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -19,10 +19,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ #define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ +#include #include #include +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -30,11 +35,13 @@ limitations under the License. 
#include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_device.h" #include "tensorflow/compiler/xrt/xrt_state.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -183,9 +190,7 @@ class XRTAllocateOp : public OpKernel { // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. class DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, - DeviceAccessor::InitScopedRef( - ctx, allocation_proto.device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( @@ -202,6 +207,110 @@ class XRTAllocateOp : public OpKernel { } }; +// Op that allocates memory for a tensor (with optional layout) and transfers it +// to the device, returning an allocation handle. 
+template +class XRTAllocateFromTensorOp : public OpKernel { + public: + explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + bool make_tuple = false; + OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple)); + std::vector minor_to_major; + if (ctx->HasAttr("layouts")) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major)); + } + OP_REQUIRES( + ctx, tf_shapes_.size() == dtypes_.size(), + errors::InvalidArgument("shapes and dtypes must be the same length")); + std::vector xla_shapes; + xla_shapes.reserve(tf_shapes_.size()); + for (int i = 0; i < tf_shapes_.size(); i++) { + xla::Shape xla_shape; + OP_REQUIRES_OK( + ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape)); + xla_shapes.push_back(std::move(xla_shape)); + } + if (xla_shapes.size() > 1 || make_tuple) { + shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes); + } else { + shape_.Swap(&xla_shapes.front()); + } + if (!minor_to_major.empty()) { + xla::Shape shape_with_layouts; + OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major, + /*layout_func=*/nullptr, + &shape_with_layouts)); + shape_.Swap(&shape_with_layouts); + } + } + + ~XRTAllocateFromTensorOp() override = default; + XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete; + XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTAllocateFromTensorOp::Compute"; + + OpInputList values; + OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values)); + OP_REQUIRES(ctx, values.size() == tf_shapes_.size(), + errors::InvalidArgument( + "Wrong number of inputs to XRTAllocateFromTensor: ", + values.size(), " vs. 
", tf_shapes_.size())); + + std::vector tensors_data; + for (size_t i = 0; i < values.size(); ++i) { + const Tensor& input_tensor = values[i]; + OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i], + errors::InvalidArgument( + "Input tensor type and input dtype do not match")); + // We allow the requested on-device shape to differ from the shape of the + // input tensor, as long as they have the same number of elements. + OP_REQUIRES( + ctx, + input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(), + errors::InvalidArgument( + "Input tensor must have the number of elements specified " + "in the matching input shape: ", + input_tensor.shape().num_elements(), " vs. ", + tf_shapes_[i].num_elements(), " at index ", i)); + tensors_data.push_back( + static_cast(DMAHelper::base(&input_tensor))); + } + // Use the buffer straight out of the input tensors to create the literal. + xla::BorrowingLiteral literal = + shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_) + : xla::BorrowingLiteral(tensors_data.front(), shape_); + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + class DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( + literal, device_ref.backend(), + device_ref.device_ordinal(), &allocation)); + + // Intern takes ownership of our reference to allocation. + int64 key; + OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = key; + ctx->set_output(0, output); + } + + private: + std::vector tf_shapes_; + DataTypeVector dtypes_; + xla::Shape shape_; +}; + // Op that takes a tuple handle input and returns a handle to a sub-tuple of the // input. 
template @@ -381,7 +490,7 @@ class XRTReadLiteralOp : public OpKernel { OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( ctx, allocation->device_ordinal(), &device_ref)); - xla::Literal literal; + xla::Literal literal(allocation->on_host_shape()); OP_REQUIRES_OK( ctx, allocation->ToLiteral(device_ref.backend(), device_ref.device_ordinal(), &literal)); @@ -393,6 +502,96 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that reads a device-resident tuple to host memory and returns it as a +// literal. +template +class XRTReadToTensorOp : public OpKernel { + public: + explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + } + ~XRTReadToTensorOp() override = default; + XRTReadToTensorOp(const XRTReadToTensorOp&) = delete; + XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReadToTensorOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not + // just scalars.) + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + + if (discard_) { + VLOG(2) << "Releasing handle " << allocation_handle; + OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager( + rm, allocation_handle)); + } + + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. 
+ class DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + + xla::Shape shape = allocation->on_host_shape(); + int output = 0; + Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus( + &shape, + [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status { + if (subshape->IsTuple()) return Status::OK(); + + xla::PrimitiveType xla_type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType( + ctx->expected_output_dtype(output), &xla_type)); + if (xla_type != subshape->element_type()) { + return errors::InvalidArgument( + "Type mismatch between buffer type (", subshape->ToString(), + ") and tensor type (", + DataTypeString(ctx->expected_output_dtype(output)), + ") for output tensor ", output); + } + + TensorShape output_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape)); + + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + ctx->allocate_output(output, output_shape, &output_tensor)); + + XRTTupleAllocation* sub; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + allocation, index, &sub, /*alias_parent_allocation=*/true)); + core::ScopedUnref sub_unref(sub); + + xla::MutableBorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral( + xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor, + &literal)); + TF_RETURN_IF_ERROR(sub->ToLiteral( + device_ref.backend(), device_ref.device_ordinal(), &literal)); + + ++output; + return Status::OK(); + }); + OP_REQUIRES_OK(ctx, status); + } + bool discard_; + DataTypeVector dtypes_; +}; + // Op that writes a new literal value into device-resident memory. 
template class XRTWriteLiteralOp : public OpKernel { @@ -455,17 +654,37 @@ class XRTReleaseAllocationOp : public OpKernel { void Compute(OpKernelContext* ctx) override { VLOG(1) << "XRTReleaseAllocationOp::Compute"; - const Tensor& allocation_handle = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_handle.shape()), - errors::Internal("handle input should be an int64 scalar")); - int64 key = allocation_handle.scalar()(); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(rm, key)); + const Tensor& allocation_handle = ctx->input(0); + auto flat_keys = allocation_handle.flat(); + for (int64 i = 0; i < flat_keys.size(); ++i) { + int64 key = flat_keys(i); + OP_REQUIRES_OK(ctx, + XRTTupleAllocation::DeleteFromResourceManager(rm, key)); + VLOG(2) << "Released allocation handle " << key; + } + } +}; + +// Op that discards a handle to device memory. +template +class XRTReleaseAllAllocationsOp : public OpKernel { + public: + explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + ~XRTReleaseAllAllocationsOp() override = default; + XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete; + XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) = + delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReleaseAllAllocationsOp::Compute"; - VLOG(2) << "Released allocation handle " << key; + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + OP_REQUIRES_OK(ctx, XRTTupleAllocation::ReleaseAllAllocations(rm)); } }; diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc index 7b3b50c69559f6003a108fdf6a1325dbdbaa80a6..9dd964e5467cd855d67764a512e95a6a18f482e1 100644 --- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc @@ -44,10 +44,10 @@ 
REGISTER_OP("XRTReleaseCompilationHandle") .SetShapeFn(tensorflow::shape_inference::NoOutputs) .Doc( R"( -Discards a computation from the compilation cache. The handle cannot be -subsequently used. +Discards one or more computation handles from the compilation cache. +The handle(s) cannot be subsequently used. -'handle' is an id returned from a XRTCompile Op. +'handle' is an ID (or vector of IDs) returned from a XRTCompile Op. )"); } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index a3d63106fa14674a9f5887ccfd908ce17dbc6384..8832270fb2730d1ba64fa069b38f4a04b61773ef 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -26,12 +26,41 @@ REGISTER_OP("XRTAllocate") .SetShapeFn(tensorflow::shape_inference::ScalarShape) .Doc( R"( -Reads a literal proto and transfers it to TPU device memory. +Reads a literal proto and transfers it to device memory. -'allocation' is a serialized xrt::TPUAllocation proto. +'allocation' is a serialized xrt::XLAAllocation proto. 'handle' is an id that can be used in other ops to refer to the allocation. )"); +REGISTER_OP("XRTAllocateFromTensor") + .Input("inputs: dtypes") + .Output("handle: int64") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .Attr("layouts: list(int) = []") + .Attr("make_tuple: bool = false") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Reads a list of tensors with optional layouts, and transfers it to device +memory. + +inputs: The tensors holding the input data. +shapes: The shapes which the tensors should have on device. The i-th shape +corresponds to the i-th input. The shapes, together with the (optional) +layouts, helps creating the fully qualified shape of the data on the device. +The shapes can differ from the corresponding input one, as long as the total +number of elements matches. 
In other words, it is possible to feed an input +tensor with shape {8} and have a corresponding shape {2,2,2}. +layouts: A vector holding the requested layout in minor-to-major sequence. +If empty, the default layout wil be used. +For a tuple, the layouts vector holds a linearized minor-to-major numbers +for all the tuple leaves, in the order they appear within the tuple. +The elements within the layouts sequence corresponding to a given tuple +subshape can be set to -1, to leave such subshape to the default shape. +handle: An id that can be used in other ops to refer to the allocation. +)"); + REGISTER_OP("XRTSubTuple") .Input("base_handle: int64") .Input("shape_index: int32") @@ -122,15 +151,44 @@ releases the handle. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTReadToTensor") + .Input("handles: int64") + .Attr("release_handles: bool = False") + .Attr("dtypes: list(type)") + .Output("tensors: dtypes") + .SetShapeFn(tensorflow::shape_inference::UnknownShape) + .Doc( + R"( +Copies allocated values from device memory and returns them as zero or more +Tensors. If a handle refers to a non-tuple buffer, a single tensor is returned. +In general, the tensors returned for a handle correspond to an in-order traversal +of a the tuple-tree value referenced by the handle. + +'handles' contains ids returned from Ops that produced on-device allocations. +At present, only a single (scalar) handle is supported. +'dtypes' are the expected types for each `Tensor` to be returned. If the +expected and actual tensor types do not match, an error is returned. +'release_handles': if True, `handles` are released. +'tensors' are the output Tensors. +)"); + REGISTER_OP("XRTReleaseAllocationHandle") .Input("handle: int64") .SetShapeFn(tensorflow::shape_inference::NoOutputs) .Doc( R"( -Discards an allocation from device memory. The handle cannot be subsequently +Discards one or more device memory handles. The handle(s) cannot be subsequently used. 
-'handle' is the id returned from the Op that produced the on-device allocation. +'handle' is the ID (or a vector of IDs) returned from the Op that produced the +on-device allocation. +)"); + +REGISTER_OP("XRTReleaseAllAllocations") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc( + R"( +Discards all the XRT allocations. All the client held handles will be invalid. )"); } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index be44a3474acdeb9905c1d21b932fa0dd10b5a212..3a19327e5b5d8072fbecdbe10e9959c8491780eb 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -24,6 +24,7 @@ cc_library( "//tensorflow/cc:client_session", "//tensorflow/cc:ops", "//tensorflow/cc:scope", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index abaa17e50e3f5e47a45f5a8a45fa2090d3efee39..1111f8240512e81c10a42a28c09f5b0a94daf1ee 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -22,6 +22,8 @@ limitations under the License. 
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -53,6 +55,14 @@ string DeviceFromFlag() { return absl::StrCat("/device:", xla_test_device, ":0"); } +std::vector GetAttrLayout(absl::Span minor_to_mayor) { + std::vector layout; + for (auto dim : minor_to_mayor) { + layout.push_back(static_cast(dim)); + } + return layout; +} + xla::LiteralProto TwoElementTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); @@ -96,14 +106,21 @@ xla::LiteralProto FloatMatrix( return array.ToProto(); } +xla::Literal ReadOutputLiteral(const std::vector& outputs, size_t idx) { + xla::LiteralProto response; + CHECK(response.ParseFromString(outputs[idx].scalar()())); + return xla::Literal::CreateFromProto(response).ValueOrDie(); +} + bool CompareLiteralProtos(const xla::LiteralProto& a, const xla::LiteralProto& b) { auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie(); auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = l_a == l_b; if (!equal) { - LOG(INFO) << "LiteralProtos don't match: " << a.DebugString() - << " != " << b.DebugString(); + LOG(INFO) << "LiteralProtos don't match:\n" + << a.DebugString() << "\n!=\n" + << b.DebugString(); } return equal; } @@ -113,8 +130,19 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a, auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = a == l_b; if (!equal) { - LOG(INFO) << "Literal and LiteralProto don't match " - << a.ToProto().DebugString() << " != " << b.DebugString(); + LOG(INFO) << "Literal and LiteralProto don't match:\n" + << a.ToProto().DebugString() << "\n!=\n" + << 
b.DebugString(); + } + return equal; +} + +bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) { + bool equal = a == b; + if (!equal) { + LOG(INFO) << "Literals don't match:\n" + << a.ToProto().DebugString() << "\n!=\n" + << b.ToProto().DebugString(); } return equal; } @@ -215,9 +243,122 @@ xla::ProgramShape XlaCompiledProgramShape( ->ComputeProgramShape(); } +TEST(RawApiTest, AllocFromTensor) { + xla::Literal literal = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + Tensor tensor; + TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = + GetAttrLayout(literal.shape().layout().minor_to_major()); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = + ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorTuple) { + xla::Literal literal0 = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + xla::Literal literal1 = + xla::LiteralUtil::CreateR2({{14.0f, -5.0f}, {16.0f, 17.0f}}); + xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); + Tensor tensor0; + TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0)); + Tensor tensor1; + TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = GetShapeLayoutVector(literal.shape()).ValueOrDie(); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + 
ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1}, + {tensor0.shape(), tensor1.shape()}, + alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorTupleSingle) { + xla::Literal literal0 = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0}); + Tensor tensor0; + TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = GetShapeLayoutVector(literal.shape()).ValueOrDie(); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true); + auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()}, + alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorRelayout) { + xla::Literal literal = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + Tensor tensor; + TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + // Use inverse array layout with the tensor data above. 
+ std::vector layout({0, 1}); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = + ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + // We have sent literal's data (in array layout) with a attribute layout + // {0,1}, so the expected literal read from device needs to be changed + // accordingly. + xla::Literal expected_literal = + xla::LiteralUtil::CreateR2({{4.0f, 6.0f}, {5.0f, 7.0f}}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response)); +} + TEST(RawApiTest, AllocAndRewrite) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); @@ -259,15 +400,138 @@ TEST(RawApiTest, AllocAndRewrite) { EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); - auto release = - ops::XRTReleaseAllocationHandle(root, Input(allocation_handle)); + Tensor release_tensor(DT_INT64, TensorShape({1})); + release_tensor.flat()(0) = allocation_handle; + + auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, AllocReleaseMany) { + xrt::XLAAllocation alloc1; + *alloc1.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + xrt::XLAAllocation alloc2; + *alloc2.mutable_value() = + xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value1 = + 
ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString()); + auto value2 = + ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString()); + auto handle1 = ops::XRTAllocate(root, value1); + auto handle2 = ops::XRTAllocate(root, value2); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle1 = outputs[0].scalar()(); + int64 allocation_handle2 = outputs[1].scalar()(); + + Tensor release_tensor(DT_INT64, TensorShape({2})); + release_tensor.flat()(0) = allocation_handle1; + release_tensor.flat()(1) = allocation_handle2; + + auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); + outputs.clear(); TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, &outputs)); } +TEST(RawApiTest, CompileAndReleaseMany) { + xrt::XLAComputation c1; + auto config1 = c1.mutable_config(); + auto shapes1 = config1->mutable_program_shape(); + *shapes1->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes1->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes1->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot()); + + xrt::XLAComputation c2; + auto config2 = c2.mutable_config(); + auto shapes2 = config2->mutable_program_shape(); + *shapes2->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes2->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes2->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); + StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(false); + + Scope root = 
Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation1 = + ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString()); + auto c_handle1 = ops::XRTCompile(root, computation1); + auto computation2 = + ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString()); + auto c_handle2 = ops::XRTCompile(root, computation2); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 compilation_handle1 = outputs[0].scalar()(); + int64 compilation_handle2 = outputs[1].scalar()(); + + Tensor release_tensor(DT_INT64, TensorShape({2})); + release_tensor.flat()(0) = compilation_handle1; + release_tensor.flat()(1) = compilation_handle2; + + auto release = ops::XRTReleaseCompilationHandle(root, release_tensor); + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, AllocAndClearAll) { + xrt::XLAAllocation alloc; + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + int64 allocation_handle = outputs[0].scalar()(); + + auto clear_all = ops::XRTReleaseAllAllocations(root); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, + {clear_all}, &outputs)); + EXPECT_EQ(outputs.size(), 0); + + auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle)); + 
EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(), + tensorflow::error::Code::NOT_FOUND); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = TwoElementTuple(); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -292,7 +556,6 @@ TEST(RawApiTest, ReadAndWriteState) { TEST(RawApiTest, ReadAndWriteStateAutoFree) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = TwoElementTuple(); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -313,7 +576,6 @@ TEST(RawApiTest, ReadAndWriteStateAutoFree) { TEST(RawApiTest, SubBuffer) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = NestedTuple(); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -354,10 +616,8 @@ TEST(RawApiTest, SubBuffer) { TEST(RawApiTest, MakeTuple) { xrt::XLAAllocation alloc_0; - alloc_0.set_device_ordinal(0); *alloc_0.mutable_value() = TwoElementTuple(); xrt::XLAAllocation alloc_1; - alloc_1.set_device_ordinal(0); *alloc_1.mutable_value() = ScalarLiteral(); // The trivial tuple that just forwards its input and releases it. 
@@ -428,10 +688,8 @@ TEST(RawApiTest, MakeTuple) { TEST(RawApiTest, CompileAndExecute) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -483,10 +741,8 @@ TEST(RawApiTest, CompileAndExecute) { TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -606,10 +862,8 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) { auto layout = xla::LayoutUtil::MakeLayout({0, 1}); xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout); xrt::XLAComputation c; @@ -692,10 +946,8 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { TEST(RawApiTest, CompileAndExecuteReturnTuple) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -745,11 +997,9 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); xrt::XLAComputation c; @@ -831,12 +1081,111 @@ TEST(RawApiTest, LeakCompilationReference) { TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs)); } +TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) { + xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, 
{2}); + xla::Shape shape = + xla::ShapeUtil::MakeTupleShape({element_shape, element_shape}); + xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape( + {element_shape, element_shape, element_shape, element_shape}); + xla::XlaBuilder builder("ReuseBuffer"); + auto param = xla::Parameter(&builder, 0, shape, "param"); + auto p0 = xla::GetTupleElement(param, 0); + auto p1 = xla::GetTupleElement(param, 1); + auto add = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {add, sub, p0, p1}); + + // Flip the tuple literals in the input handle. + builder.SetUpAlias({1}, 0, {0}); + builder.SetUpAlias({0}, 0, {1}); + + auto computation = builder.Build().ValueOrDie(); + + auto literal0 = xla::LiteralUtil::CreateR1({1.0f, 2.0f}); + auto literal1 = xla::LiteralUtil::CreateR1({5.0f, 9.0f}); + auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); + + xrt::XLAAllocation param_alloc; + *param_alloc.mutable_value() = literal.ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = shape.ToProto(); + *shapes->mutable_result() = return_shape.ToProto(); + StoreComputationSnapshot(computation, c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + ClientSession session(root); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto c_data = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, c_data); + auto param_value = ops::Const(root.WithDevice("/device:CPU:0"), + param_alloc.SerializeAsString()); + auto param_handle = ops::XRTAllocate(root, param_value); + TF_ASSERT_OK(root.status()); + + std::vector outputs; + TF_EXPECT_OK(session.Run({param_handle}, &outputs)); + + int64 alloc_handle = 
outputs[0].scalar()(); + + // Note that we release the result handle immediately, but since we aliased + // the output buffers onto the input allocation ones (held in alloc_handle), + // we can fetch the result from there. + auto result = + ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)}); + auto read_back = ops::XRTReadLiteral(root, result); + auto release = ops::XRTReleaseAllocationHandle( + root.WithControlDependencies(read_back), result); + TF_ASSERT_OK(root.status()); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back}, + {release}, &outputs)); + + xla::Literal exec_literal = ReadOutputLiteral(outputs, 0); + auto exec_literal_parts = exec_literal.DecomposeTuple(); + ASSERT_EQ(exec_literal_parts.size(), 4); + + EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0)); + EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1)); + + // Now we read back the original input handle values, which at this point + // should contain the result of the XLA computation. + auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle)); + TF_ASSERT_OK(root.status()); + auto release_handle = ops::XRTReleaseAllocationHandle( + root.WithControlDependencies(read_handle), Input(alloc_handle)); + TF_ASSERT_OK(root.status()); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_handle}, + {release_handle}, &outputs)); + + xla::Literal return_literal = ReadOutputLiteral(outputs, 0); + + auto expected_literal0 = xla::LiteralUtil::CreateR1({6.0f, 11.0f}); + auto expected_literal1 = xla::LiteralUtil::CreateR1({-4.0f, -7.0f}); + // The first element of the computation returned tuple would be the add + // (expected_literal0), but since we flipped the buffers, the sub + // (expected_literal1) should come first. 
+ auto expected_literal = + xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0}); + + EXPECT_TRUE(CompareLiterals(return_literal, expected_literal)); +} + TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = xla::LiteralUtil::CreateR0(11031965).ToProto(); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = xla::LiteralUtil::CreateR0(4091934).ToProto(); xrt::XLAComputation c; @@ -850,6 +1199,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XRTExecutionConfig e; e.set_release_input_handles(true); e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); auto e_config = diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 378bb9246f27b8106310d565435404d7ac260a87..84adee7392825d408dd88dd74dc0c1bc7b06d7c4 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -59,7 +59,7 @@ message XLAComputation { // Literal to allocate space for, and transfer to, device memory. 
message XLAAllocation { - int32 device_ordinal = 1; + reserved 1; xla.LiteralProto value = 2; } diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc index d1405eae468492748ae88d842334a922dce272c6..8bf0f28d2233d9e7593365bc42187e327a1c4ac4 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc @@ -273,6 +273,8 @@ Status XRTCompilationCache::Lookup( return Status::OK(); } -string XRTCompilationCache::DebugString() { return "XRTCompilationCache"; } +string XRTCompilationCache::DebugString() const { + return "XRTCompilationCache"; +} } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h index c43d0fc47873abdc82ee937c155bebc346a05f17..7398e847d8b744f947adb03e1bcfd5c0a5b2cc55 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h @@ -118,7 +118,7 @@ class XRTCompilationCache : public ResourceBase { // EntryRef holding the program is returned in entry. Status Lookup(int64 uid, std::unique_ptr* entry); - string DebugString() override; + string DebugString() const override; private: // An entry in the compilation cache. 
The entry is deleted once it has been diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc index ea40e6c895c4f6af13b74735685f2c342181ada9..34cb64742a20985b29d8e153bbaf5ee184fd385d 100644 --- a/tensorflow/compiler/xrt/xrt_device.cc +++ b/tensorflow/compiler/xrt/xrt_device.cc @@ -43,4 +43,12 @@ namespace tensorflow { return Status::OK(); } +/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef( + OpKernelContext* ctx, ScopedRef* scoped_ref) { + const XlaDevice::Metadata* metadata; + TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); + scoped_ref->Acquire(metadata->client()); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h index 1e3fddd2a72a3657d1e115375133c244772ea9f3..fb010651d9bf76c540517b9596e472c881241d8a 100644 --- a/tensorflow/compiler/xrt/xrt_device.h +++ b/tensorflow/compiler/xrt/xrt_device.h @@ -59,6 +59,8 @@ class XRTGenericDeviceAccessor { static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref); + + static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref); }; } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 31603e044d17baa3ae0ae583f61837811bb12495..1b3bcbea4c1228944a6604fc923228024e74d700 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -31,7 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" @@ -133,7 +132,8 @@ Status AllocateScopedShapedBuffer( XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, int device_ordinal, xla::DeviceMemoryAllocator* allocator) - : allocation_(allocation), + : size_(allocation.size()), + allocation_(allocation), device_ordinal_(device_ordinal), allocator_(allocator) { if (VLOG_IS_ON(2)) { @@ -181,7 +181,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } /*static*/ Status XRTTupleAllocation::CreateAndTransfer( - const xla::Literal& literal, xla::Backend* backend, int device_ordinal, + const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation) { auto transfer_manager = backend->transfer_manager(); auto allocator = backend->memory_allocator(); @@ -220,12 +220,22 @@ XRTTupleAllocation::~XRTTupleAllocation() { } Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, - xla::Literal* literal) { + xla::MutableLiteralBase* literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); - TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( - stream.get(), ToShapedBuffer())); - return Status::OK(); + + // Validate the allocation buffers as if nulls gets to + // TransferLiteralFromDevice() a CHECK is issued. 
+ xla::ShapedBuffer shaped_buffer = ToShapedBuffer(); + for (auto& index_buffer : shaped_buffer.buffers()) { + if (index_buffer.second.is_null()) { + return errors::InvalidArgument("Literal buffer at index ", + index_buffer.first.ToString(), + " has been released"); + } + } + return transfer_manager->TransferLiteralFromDevice(stream.get(), + shaped_buffer, *literal); } Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, @@ -272,6 +282,11 @@ const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() { return rm->Delete(kTupleContainer, key_string); } +/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) { + VLOG(1) << "Releasing all XRT held device memory"; + return rm->Cleanup(kTupleContainer); +} + // Helper typedef to make ShapeTree ForEach helper lambda signatures more // readable. They need a type of const T& where in this case T is the // following pointer. @@ -500,11 +515,34 @@ xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() { return shaped_buffer; } +Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source, + const xla::ShapeIndex& source_index, + const xla::ShapeIndex& dest_index) { + XRTBufferAllocation* source_buffer = source.buffers_.element(source_index); + XRTBufferAllocation* dest_buffer = buffers_.element(dest_index); + // We allow the destination size being zero, because there are cases where we + // are coming in later filling in null/uninitialized device buffers. + // In all other cases, the size of the new buffer must match. 
+ if (source_buffer->size() != dest_buffer->size() && + dest_buffer->size() != 0) { + return errors::InvalidArgument( + "Source buffer at index ", source_index.ToString(), + " does not match the size of destination buffer at index ", + dest_index.ToString(), ": ", source_buffer->size(), " vs ", + dest_buffer->size()); + } + *buffers_.mutable_element(dest_index) = source_buffer; + source_buffer->Ref(); + dest_buffer->Unref(); + return Status::OK(); +} + xla::ShapeTree -XRTTupleAllocation::ToDeviceMemoryTree(bool release) { +XRTTupleAllocation::ToDeviceMemoryTree( + const std::function& release_checker) { xla::ShapeTree shaped_tree(on_device_shape()); for (const auto& buffer : buffers_) { - if (!release) { + if (!release_checker(buffer.first)) { *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation(); } else { *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory( diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 3664c0cd4e6ad26945ae1012208fdb006164a066..6519da30d02e41da5a862cadd2133bd8dd8b42d7 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ #define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ +#include #include #include #include @@ -58,7 +59,14 @@ class XRTBufferAllocation : public core::RefCounted { // freed when the reference count drops to zero. void DiscardAllocation(); + // Returns the expected size of the allocation. Since DiscardAllocation() will + // set allocation_ to {null,0}, and since later we might want to replace the + // discarded buffer with a new one, we need to be able to verify the size + // compatibility. 
+ uint64 size() const { return size_; } + private: + uint64 size_ = 0; se::DeviceMemoryBase allocation_; int device_ordinal_; xla::DeviceMemoryAllocator* allocator_; @@ -80,7 +88,7 @@ class XRTTupleAllocation : public ResourceBase { // Allocates new device memory buffers sufficient to store literal, transfers // literal to that memory, and returns a XRTTupleAllocation handle to the // allocated buffers. - static Status CreateAndTransfer(const xla::Literal& literal, + static Status CreateAndTransfer(const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation); @@ -129,13 +137,17 @@ class XRTTupleAllocation : public ResourceBase { // Deletes the reference in the rm to an allocation interned under key. static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key); + // Releases all the device memory allocated by XRT within the resource + // manager. + static Status ReleaseAllAllocations(ResourceMgr* rm); + // Adds the allocation to a ResourceMgr and returns the key that will be used // to retrieve it. Transfers a reference on *this to rm. Status Intern(ResourceMgr* rm, int64* key); // Copies the allocation from device to host and returns it in literal. Status ToLiteral(xla::Backend* backend, int device_ordinal, - xla::Literal* literal); + xla::MutableLiteralBase* literal); // Write a new literal value to the allocation. Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); @@ -164,11 +176,20 @@ class XRTTupleAllocation : public ResourceBase { // the same shape as on_host_shape. xla::ShapedBuffer ToShapedBuffer(); - // Returns the device memory tree of this allocation. If 'release' is set, the - // ownership of the device memory is transferred to the result. - xla::ShapeTree ToDeviceMemoryTree(bool release); + // Aliases the source buffer at source_index into the current tuple allocation + // dest_index. 
+ Status AliasBufferFrom(const XRTTupleAllocation& source, + const xla::ShapeIndex& source_index, + const xla::ShapeIndex& dest_index); + + // Returns the device memory tree of this allocation. If the release_checker + // function returns true for a given index, the ownership of the device memory + // at that index is transferred to the result. Every attempt to read the value + // at that index will fail. + xla::ShapeTree ToDeviceMemoryTree( + const std::function& release_checker); - string DebugString() override { return "XLA allocation handle"; } + string DebugString() const override { return "XLA allocation handle"; } private: // Creates a new handle with (tuple) shape. diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 832db0f4ab46911e067d17b4a125706c276cf798..0173b8bb064c7b2fb8a0df018204515b24cfa718 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -63,7 +63,6 @@ py_library( "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", - "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", @@ -197,7 +196,7 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_kernels", ], }) + if_not_windows([ - "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", + "//tensorflow/contrib/tensorrt:trt_op_kernels", ]), ) @@ -219,7 +218,6 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", - "//tensorflow/contrib/tpu:all_ops", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], @@ -239,7 +237,7 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_ops_op_lib", ], }) + if_not_windows([ - "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", + "//tensorflow/compiler/tf2tensorrt:trt_op_libs", ]) + select({ 
"//tensorflow:android": [], "//tensorflow:ios": [], diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 4f1a2a5693235183c8f486817b82c8c81fa389ec..48d5296c71cbdb470fa405b30547a32b7022f29b 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -20,13 +20,14 @@ from __future__ import division from __future__ import print_function import os +import platform # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import autograph from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import checkpoint -if os.name != "nt": +if os.name != "nt" and platform.machine() != "s390x": from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder @@ -91,7 +92,6 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -from tensorflow.contrib.lite.python import lite from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.recurrent.python import recurrent_api as recurrent @@ -103,6 +103,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader ffmpeg = LazyLoader("ffmpeg", globals(), "tensorflow.contrib.ffmpeg") del os +del platform + del LazyLoader del absolute_import diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index f0b1c92cf7e4b760381da38febd9682ce2a4f27c..5608e7ddafa25757484d8c845c8c84a5691e143c 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -73,8 +73,7 @@ cc_binary( "-z defs", "-s", "-Wl,--gc-sections", - "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. 
- "$(location {})".format(LINKER_SCRIPT), + "-Wl,--version-script,$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 44532cb078f9bd1578172f8a7d8a4b55cd21a7cb..831c613f2c8c9a4fcc2cb9d313077fe79ee96fd7 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -186,8 +186,8 @@ "\n", " def __init__(self):\n", " super(RnnColorbot, self).__init__()\n", - " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", - " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256, dtype=tf.float32)\n", + " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128, dtype=tf.float32)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", " def _rnn_layer(self, chars, cell, batch_size, training):\n", @@ -241,7 +241,7 @@ " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = (length - 1, range(batch_size))\n", + " indices = (length - 1, list(range(batch_size)))\n", " indices = tf.stack(indices, 1)\n", " sequence_ends = tf.gather_nd(seq, indices)\n", " return self.relu_layer(sequence_ends)\n", diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 648f3ebb05646a66144bcb118347cbc391909409..5174afe0a63d37e3ea3e19ac9bab644d1d83ecf1 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -37,6 +37,7 @@ py_library( cc_library( name = "batch_ops_kernels", deps = [ + "//tensorflow/core:batch_ops_op_lib", "//tensorflow/core/kernels:batch_kernels", ], alwayslink = 1, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc 
b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index 6138d7912601344ef7422fd50fb35c8401fd2e63..f0637595db08cbeb3b3ee0c94c5399df4c8c83e6 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { - namespace { class BigtableClientOp : public OpKernel { @@ -341,8 +340,8 @@ class ToBigtableOp : public AsyncOpKernel { } template - Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece& argument_name, T* output) { + Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { const Tensor* argument_t; TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); if (!TensorShapeUtils::IsScalar(argument_t->shape())) { @@ -360,5 +359,4 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), } // namespace } // namespace data - } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index 4652021fecabfa11fa6a8754dc884d89e151b590..e3b4535bac4a01a1277290e0d1ea6d3c7613731c 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -42,7 +42,7 @@ class BigtableClientResource : public ResourceBase { return client_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("BigtableClientResource(project_id: ", project_id_, ", instance_id: ", instance_id_, ")"); } @@ -67,7 +67,7 @@ class BigtableTableResource : public ResourceBase { ::google::cloud::bigtable::noex::Table& table() { return table_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat( "BigtableTableResource(client: ", client_->DebugString(), ", table: ", table_name_, ")"); diff --git 
a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index e95dc577184f7e81d942755b41065f52131ce9f6..d9fce6e09f47ab05074f0b4c03dd8e672ed3d2ce 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" -#include "google/bigtable/v2/data.pb.h" +#include "external/com_github_googleapis_googleapis/google/bigtable/v2/data.pb.h" #include "google/protobuf/wrappers.pb.h" #include "re2/re2.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -335,6 +335,17 @@ grpc::Status BigtableTestClient::ReadModifyWriteRow( return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "ReadModifyWriteRow not implemented."); } +std::unique_ptr> +BigtableTestClient::AsyncReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to AsyncReadModifyWriteRow:" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::unique_ptr< grpc::ClientReaderInterface> BigtableTestClient::ReadRows( @@ -399,6 +410,28 @@ BigtableTestClient::AsyncMutateRows( return nullptr; } +std::unique_ptr> +BigtableTestClient::AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + +std::unique_ptr< + grpc::ClientAsyncReaderInterface> +BigtableTestClient::AsyncReadRows( + grpc::ClientContext* context, + const google::bigtable::v2::ReadRowsRequest& request, + grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to 
InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index c4a1f06bc504c3565c7bb09b42e48e7fbddb9cc6..63d59b32dd17a2f58d3413932b69f4d704c84e48 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -46,6 +46,13 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { google::bigtable::v2::ReadModifyWriteRowRequest const& request, google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + std::unique_ptr> + AsyncReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + grpc::CompletionQueue* cq) override; + std::unique_ptr< grpc::ClientReaderInterface> ReadRows(grpc::ClientContext* context, @@ -80,6 +87,19 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { const ::google::bigtable::v2::MutateRowsRequest& request, ::grpc::CompletionQueue* cq, void* tag) override; + std::unique_ptr> + AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) override; + + std::unique_ptr< + grpc::ClientAsyncReaderInterface> + AsyncReadRows(grpc::ClientContext* context, + const google::bigtable::v2::ReadRowsRequest& request, + grpc::CompletionQueue* cq, void* tag) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc index 416b719e30aa5f2504449d151a48e95c9105c68b..39c2a2e775d5d5287b137bf33eef66251738e6d3 100644 --- 
a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc +++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc @@ -59,7 +59,7 @@ REGISTER_OP("BigtablePrefixKeyDataset") .Input("table: resource") .Input("prefix: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -68,14 +68,14 @@ REGISTER_OP("BigtableRangeKeyDataset") .Input("start_key: string") .Input("end_key: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("BigtableSampleKeysDataset") .Input("table: resource") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -85,7 +85,7 @@ REGISTER_OP("BigtableSampleKeyPairsDataset") .Input("start_key: string") .Input("end_key: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -100,7 +100,7 @@ REGISTER_OP("BigtableScanDataset") .Input("columns: string") .Input("probability: float") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. 
.SetShapeFn(shape_inference::ScalarShape); diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index b6cdc7aab0320fe5f457288ada03a46e18a694cc..fa64055dfd65a134afdf46cebccb7f7d96106502 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -489,7 +489,7 @@ class BigtableTable(object): "len(dataset.output_types))") return gen_bigtable_ops.dataset_to_bigtable( self._resource, - dataset._as_variant_tensor(), # pylint: disable=protected-access + dataset._variant_tensor, # pylint: disable=protected-access column_families, columns, timestamp) @@ -582,13 +582,14 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource): """_BigtableKeyDataset is an abstract class representing the keys of a table. """ - def __init__(self, table): + def __init__(self, table, variant_tensor): """Constructs a _BigtableKeyDataset. Args: table: a Bigtable class. + variant_tensor: DT_VARIANT representation of the dataset. 
""" - super(_BigtableKeyDataset, self).__init__() + super(_BigtableKeyDataset, self).__init__(variant_tensor) self._table = table @property @@ -601,13 +602,11 @@ class _BigtablePrefixKeyDataset(_BigtableKeyDataset): """ def __init__(self, table, prefix): - super(_BigtablePrefixKeyDataset, self).__init__(table) self._prefix = prefix - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_prefix_key_dataset( - table=self._table._resource, # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset( + table=table._resource, # pylint: disable=protected-access prefix=self._prefix) + super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor) class _BigtableRangeKeyDataset(_BigtableKeyDataset): @@ -615,15 +614,13 @@ class _BigtableRangeKeyDataset(_BigtableKeyDataset): """ def __init__(self, table, start, end): - super(_BigtableRangeKeyDataset, self).__init__(table) self._start = start self._end = end - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_range_key_dataset( - table=self._table._resource, # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset( + table=table._resource, # pylint: disable=protected-access start_key=self._start, end_key=self._end) + super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor) class _BigtableSampleKeysDataset(_BigtableKeyDataset): @@ -633,11 +630,9 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset): # TODO(saeta): Expose the data size offsets into the keys. 
def __init__(self, table): - super(_BigtableSampleKeysDataset, self).__init__(table) - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_sample_keys_dataset( - table=self._table._resource) # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset( + table=table._resource) # pylint: disable=protected-access + super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor) class _BigtableLookupDataset(dataset_ops.DatasetSource): @@ -651,20 +646,18 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource): self._normalized = normalized self._column_families = [i[0] for i in normalized] self._columns = [i[1] for i in normalized] + variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._variant_tensor, # pylint: disable=protected-access + table=self._table._resource, # pylint: disable=protected-access + column_families=self._column_families, + columns=self._columns) + super(_BigtableLookupDataset, self).__init__(variant_tensor) @property def _element_structure(self): return structure.NestedStructure(tuple( [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_bigtable_ops.bigtable_lookup_dataset( - keys_dataset=self._dataset._as_variant_tensor(), - table=self._table._resource, - column_families=self._column_families, - columns=self._columns) - class _BigtableScanDataset(dataset_ops.DatasetSource): """_BigtableScanDataset represents a dataset that retrieves keys and values. 
@@ -679,14 +672,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): self._columns = [i[1] for i in normalized] self._probability = probability self._num_outputs = len(normalized) + 1 # 1 for row key - - @property - def _element_structure(self): - return structure.NestedStructure(tuple( - [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_scan_dataset( + variant_tensor = gen_bigtable_ops.bigtable_scan_dataset( table=self._table._resource, # pylint: disable=protected-access prefix=self._prefix, start_key=self._start, @@ -694,6 +680,13 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): column_families=self._column_families, columns=self._columns, probability=self._probability) + super(_BigtableScanDataset, self).__init__(variant_tensor) + + @property + def _element_structure(self): + return structure.NestedStructure( + tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): @@ -705,17 +698,15 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): self._prefix = prefix self._start = start self._end = end + variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix, + start_key=self._start, + end_key=self._end) + super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor) @property def _element_structure(self): return structure.NestedStructure( (structure.TensorStructure(dtypes.string, []), structure.TensorStructure(dtypes.string, []))) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_bigtable_ops.bigtable_sample_key_pairs_dataset( - table=self._table._resource, - prefix=self._prefix, - start_key=self._start, - end_key=self._end) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD 
b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index d3b23d949ee2c7674c3918d39e8b71d76eefcfec..64e4c4560ba3a1b177db12a09997ff7afe8775a3 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -193,8 +193,9 @@ py_test( py_test( name = "estimator_test", - size = "large", + size = "medium", srcs = ["estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "no_gpu", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index a178820841c4c8bcb7f5742babdb6d0f4825de31..5ffbb9067081d7440ab5e11290697b822051bee5 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -84,12 +84,10 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) + for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -179,8 +177,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): `[batch_size, label_dimension]`). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. 
- label_name: String, name of the key in label dict. Can be null if label - is a tensor (single headed models). + label_name: String, name of the key in label dict. Can be null if label is + a tensor (single headed models). weight_column_name: Name of the column for weights, or None if not weighted. model_dir: Directory for model exports, etc. @@ -195,11 +193,11 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -286,11 +284,11 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) 
+ for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -353,10 +351,9 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. head: `Head` instance. - ranking_model_pair_keys: Keys to distinguish between features - for left and right part of the training pairs for ranking. For example, - for an Example with features "a.f1" and "b.f1", the keys would be - ("a", "b"). + ranking_model_pair_keys: Keys to distinguish between features for left and + right part of the training pairs for ranking. For example, for an + Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. weight_column_name: Name of the column for weights, or None if not @@ -376,12 +373,10 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) 
+ for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -417,12 +412,12 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) + # When using this estimator, make sure to regularize the hessian (at least l2, # min_node_weight)! # TODO(nponomareva): extend to take multiple quantiles in one go. class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): - """An estimator that does quantile regression and returns quantile estimates. - """ + """An estimator that does quantile regression and returns quantile estimates.""" def __init__(self, learner_config, @@ -449,8 +444,8 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. quantiles: a list of quantiles for the loss, each between 0 and 1. - label_dimension: Dimension of regression label. This is the size - of the last dimension of the labels `Tensor` (typically, this has shape + label_dimension: Dimension of regression label. This is the size of the + last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). When label_dimension>1, it is recommended to use multiclass strategy diagonal hessian or full hessian. num_trees: An int, number of trees to build. @@ -469,11 +464,11 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. 
For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -519,6 +514,7 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) + # ================== New Estimator interface=================================== # The estimators below use new core Estimator interface and must be used with # new feature columns and heads. @@ -534,10 +530,8 @@ def core_multiclass_head( def loss_fn(labels, logits): result = losses.per_example_maxent_loss( - labels=labels, - logits=logits, - weights=weight_column, - num_classes=n_classes) + # Don't pass the weights: head already multiplies by them. + labels=labels, logits=logits, weights=None, num_classes=n_classes) return result[0] # pylint:disable=protected-access @@ -564,7 +558,8 @@ def core_quantile_regression_head( result = losses.per_example_quantile_regression_loss( labels=labels, predictions=logits, - weights=weight_column, + # Don't pass the weights: head already multiplies by them. + weights=None, quantile=quantiles) return result[0] @@ -623,11 +618,11 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): the bias. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. 
For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree num_quantiles: Number of quantiles to build for numeric feature values. """ @@ -685,10 +680,9 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. head: `Head` instance. - ranking_model_pair_keys: Keys to distinguish between features - for left and right part of the training pairs for ranking. For example, - for an Example with features "a.f1" and "b.f1", the keys would be - ("a", "b"). + ranking_model_pair_keys: Keys to distinguish between features for left and + right part of the training pairs for ranking. For example, for an + Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. weight_column_name: Name of the column for weights, or None if not @@ -703,12 +697,10 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) 
+ for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree num_quantiles: Number of quantiles to build for numeric feature values. Raises: @@ -748,8 +740,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): # TODO(nponomareva): extend to take multiple quantiles in one go. class CoreGradientBoostedDecisionTreeQuantileRegressor( core_estimator.Estimator): - """An estimator that does quantile regression and returns quantile estimates. - """ + """An estimator that does quantile regression and returns quantile estimates.""" def __init__(self, learner_config, @@ -775,8 +766,8 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. quantiles: a list of quantiles for the loss, each between 0 and 1. - label_dimension: Dimension of regression label. This is the size - of the last dimension of the labels `Tensor` (typically, this has shape + label_dimension: Dimension of regression label. This is the size of the + last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). When label_dimension>1, it is recommended to use multiclass strategy diagonal hessian or full hessian. num_trees: An int, number of trees to build. @@ -795,11 +786,11 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( the bias. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) 
+ for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree num_quantiles: Number of quantiles to build for numeric feature values. """ if len(quantiles) > 1: @@ -814,7 +805,9 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( params={ 'head': core_quantile_regression_head( - quantiles[0], label_dimension=label_dimension), + quantiles[0], + label_dimension=label_dimension, + weight_column=weight_column_name), 'feature_columns': feature_columns, 'learner_config': diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index ee052ac60387d8f993e4942dd7dff39e191dd3a4..5a8b2ba9caf0a9813cb5b3409b8a0dc3de0a45d7 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -399,8 +399,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def testQuantileRegression(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -413,7 +413,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, + num_trees=12, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -428,31 +428,12 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): 
self.assertTrue(frac_below_upper >= 0.92) self.assertTrue(frac_below_upper <= 0.98) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - # Multi-dimensional quantile regression. def testQuantileRegressionMultiDimLabel(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -467,7 +448,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): quantiles=[0.95], learner_config=learner_config, label_dimension=2, - num_trees=100, + num_trees=18, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -487,37 +468,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper_0 <= 0.98) self.assertTrue(frac_below_upper_1 >= 0.92) self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.92) - self.assertTrue(frac_both_below_upper <= 0.98) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - 
two_dimension=True) - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - label_dimension=2, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.92) - self.assertTrue(frac_both_above_lower <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.91) + self.assertTrue(frac_both_below_upper <= 0.99) class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -712,11 +664,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) - # One dimensional quantile regression. - def testQuantileRegression(self): + # Quantile regression in core is the same as in non core estimator, so we + # just check that it does not fail. 
+ def testQuantileRegressionDoesNotThroughException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 + learner_config.constraints.max_tree_depth = 1 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -731,112 +684,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_upper.train(input_fn=train_input_fn, steps=1000) - result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper >= 0.92) - self.assertTrue(frac_below_upper <= 0.98) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - - # Multi-dimensional quantile regression. 
- def testQuantileRegressionMultiDimLabel(self): - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.tree_complexity = ( - 1.0 / _QUANTILE_REGRESSION_SIZE) - - train_input_fn, test_input_fn, y = _quantile_regression_input_fns( - two_dimension=True) - y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) - - # 95% percentile. - model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.95], - learner_config=learner_config, - num_trees=100, - label_dimension=2, + num_trees=1, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) model_upper.train(input_fn=train_input_fn, steps=1000) result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - count_below_upper = np.count_nonzero(upper > y, axis=0) - count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) - frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) - frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) - frac_both_below_upper = round(1. 
* count_both_below_upper / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper_0 >= 0.92) - self.assertTrue(frac_below_upper_0 <= 0.98) - self.assertTrue(frac_below_upper_1 >= 0.92) - self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.92) - self.assertTrue(frac_both_below_upper <= 0.98) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - label_dimension=2, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.92) - self.assertTrue(frac_both_above_lower <= 0.98) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index a6e422847d3914188bca9e6dff797ba1ffb06749..eecf3c5aeb6c6785cae3fd5808954a73db6190d6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training_util @@ -88,6 +89,12 @@ def model_builder(features, if config is None: raise ValueError("Missing estimator RunConfig.") + if config.session_config is not None: + session_config = config.session_config + session_config.allow_soft_placement = True + else: + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + config = config.replace(session_config=session_config) center_bias = params["center_bias"] diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 6d78e27e8f69ea289b686af8402bd91967f997f4..65276242abaf96de8b1936365278b18f8bba93a9 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ 
b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -538,7 +538,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { partition_boundaries[non_empty_partitions[root_idx]]; float best_gain = std::numeric_limits::lowest(); - int32 best_dimension_idx = 0; bool default_right = false; int32 best_element_idx = 0; @@ -571,7 +570,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { const DimensionBoundary& dimension_and_start = dimension_boundaries[j]; - const int32 dimension_id = dimension_and_start.dimension_id; int start_index = dimension_and_start.start_index; // Even for the last dimension, we always have additional dummy @@ -630,7 +628,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { best_right_node_stats = right_stats_default_left; best_element_idx = element_idx; default_right = false; - best_dimension_idx = dimension_id; } } // Consider calculating the default direction only when there were @@ -648,7 +645,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { best_right_node_stats = right_stats_default_right; best_element_idx = element_idx; default_right = true; - best_dimension_idx = dimension_id; } } } diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index e446c411a8d5075563b8f8b912b29df310e16c8c..6faf6963011b698a3b233329d87471da7608e44a 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -96,7 +96,7 @@ class StatsAccumulatorResource : public boosted_trees::StampedResource { TensorShapeUtils::IsScalar(hessian_shape)); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("StatsAccumulatorResource[size=", values_.size(), "]"); } diff --git 
a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 42d69645acaae063fcd46bd1f6c819ccb68f48bd..aa3f24f08a0f762507df83def72e7d595265221f 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -227,7 +227,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): tree_ensemble_config=tree_ensemble_config.SerializeToString(), name="restore_tree") resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() + variables.global_variables_initializer().run() my_saver = saver.Saver() # Add the second tree and replace the ensemble of the handle. diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index fca22c71a83459cb290eaebcf107cf1c14c222b7..ad6ff0a861af896ef0dd254bd47752d76378d63a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -33,7 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") @@ -62,8 +62,8 @@ class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject): saver.BaseSaverBuilder.SaveSpec(ensemble_config, slice_spec, name + "_config"), ] - super(TreeEnsembleVariableSavable, - self).__init__(tree_ensemble_handle, specs, name) + super(TreeEnsembleVariableSavable, self).__init__(tree_ensemble_handle, + specs, name) self._tree_ensemble_handle = tree_ensemble_handle 
self._create_op = create_op @@ -115,7 +115,7 @@ class TreeEnsembleVariable(tracking.TrackableResource): def _gather_saveables_for_checkpoint(self): return { - "tree_ensemble_variable": + self.resource_handle.op.name + "/tree_ensemble_variable": functools.partial( TreeEnsembleVariableSavable, tree_ensemble_handle=self.resource_handle, @@ -131,8 +131,8 @@ def tree_ensemble_variable(stamp_token, Args: stamp_token: The initial stamp token value for the ensemble resource. - tree_ensemble_config: A `Tensor` of type `string`. - Serialized proto of the tree ensemble. + tree_ensemble_config: A `Tensor` of type `string`. Serialized proto of the + tree ensemble. name: A name for the ensemble variable. container: An optional `string`. Defaults to `""`. diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 0c319cc9bd1f720eb404a9da05227c5807ec874f..aff7105e94729942efc6e3e9d3ae23b733e8f5ed 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking # Pattern to remove all non alpha numeric from a string. 
_PATTERN = re.compile(r"[\W_]+") diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index ad1191d41236e71008bff8c8a7fbd42c16e3f9c5..2a0a206d97bbf01ac382531df31a66d429842bbb 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 9fdc2fc0c2c7b85502f7a3f9ae7c85cf05d5916c..e78ec476ab3b43e5eb56a2502008bb8020ae97e0 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -566,9 +566,10 @@ class GradientBoostedDecisionTreeModel(object): # Determine if ensemble is colocated with the inputs. if self._ensemble_handle.device != input_deps[0].device: # Create a local ensemble and get its local stamp. 
- with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as name: + with ops.name_scope("local_ensemble", "TreeEnsembleVariable"): local_ensemble_handle = ( - gen_model_ops.decision_tree_ensemble_resource_handle_op(name=name)) + gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._ensemble_handle.op.name + "/local_ensemble")) create_op = gen_model_ops.create_tree_ensemble_variable( local_ensemble_handle, stamp_token=-1, tree_ensemble_config="") with ops.control_dependencies([create_op]): @@ -614,13 +615,19 @@ class GradientBoostedDecisionTreeModel(object): predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) return constant_op.constant(-1, dtype=dtypes.int32) - def update_stats(self, loss, predictions_dict): + def update_stats(self, loss, predictions_dict, gradients=None, hessians=None): """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. 
Returns: Three values: @@ -642,13 +649,14 @@ class GradientBoostedDecisionTreeModel(object): predictions = predictions_dict[PREDICTIONS] partition_ids = predictions_dict[PARTITION_IDS] ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] - gradients = gradients_impl.gradients( - loss, - predictions, - name="Gradients", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if gradients is None: + gradients = gradients_impl.gradients( + loss, + predictions, + name="Gradients", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy class_id = self._get_class_id(predictions_dict) @@ -657,17 +665,20 @@ class GradientBoostedDecisionTreeModel(object): # We build one vs rest trees. if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. - hessians = gradients_impl.gradients( - gradients, - predictions, - name="Hessian", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is None: + hessians = gradients_impl.gradients( + gradients, + predictions, + name="Hessian", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] squeezed_gradients = array_ops.squeeze(gradients, axis=[1]) squeezed_hessians = array_ops.squeeze(hessians, axis=[1]) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. hessians = array_ops.stack(hessian_list, axis=1) @@ -678,6 +689,8 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = array_ops.squeeze( _get_column_by_index(hessians, class_id)) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") # Other multiclass strategies. 
if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: hessian_list = self._full_hessian(gradients, predictions) @@ -835,9 +848,9 @@ class GradientBoostedDecisionTreeModel(object): stats_update_ops.append( control_flow_ops.cond( continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), control_flow_ops.no_op)) + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator, + hessians), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -1162,7 +1175,8 @@ class GradientBoostedDecisionTreeModel(object): def get_max_tree_depth(self): return self._max_tree_depth - def train(self, loss, predictions_dict, labels): + def train(self, loss, predictions_dict, labels, gradients=None, + hessians=None): """Updates the accumalator stats and grows the ensemble. Args: @@ -1171,6 +1185,12 @@ class GradientBoostedDecisionTreeModel(object): about predictions per example. labels: Rank 2 `Tensor` representing labels per example. Has no effect on the training and is only kept for backward compatibility. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. Returns: An op that adds a new tree to the ensemble. @@ -1179,7 +1199,8 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ del labels # unused; kept for backward compatibility. 
- update_op, _, training_state = self.update_stats(loss, predictions_dict) + update_op, _, training_state = self.update_stats(loss, predictions_dict, + gradients, hessians) with ops.control_dependencies(update_op): return self.increment_step_counter_and_maybe_update_ensemble( predictions_dict, training_state) @@ -1271,21 +1292,28 @@ class GradientBoostedDecisionTreeModel(object): ps_ops=ps_ops, ps_strategy=ps_strategy) - def _make_update_bias_stats_fn(self, ensemble_stamp, predictions, gradients, - bias_stats_accumulator): + def _make_update_bias_stats_fn(self, + ensemble_stamp, + predictions, + gradients, + bias_stats_accumulator, + hessians=None): """A method to create the function which updates the bias stats.""" def _update_bias_stats(): """A method to update the bias stats.""" # Get reduced gradients and hessians. grads_sum = math_ops.reduce_sum(gradients, 0) - hess = gradients_impl.gradients( - grads_sum, - predictions, - name="Hessians", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is not None: + hess = hessians + else: + hess = gradients_impl.gradients( + grads_sum, + predictions, + name="Hessians", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] hess_sum = math_ops.reduce_sum(hess, 0) # Accumulate gradients and hessians. 
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 92068e88a76cb8bfdd394c1093347a8fb8a63449..7e45d0b2cecefa4bdec77d6cf7cfca7dba04db9c 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -43,7 +43,7 @@ from tensorflow.python.platform import googletest def _squared_loss(label, unused_weights, predictions): """Unweighted loss implementation.""" loss = math_ops.reduce_sum( - math_ops.square(predictions - label), 1, keepdims=True) + math_ops.squared_difference(predictions, label), 1, keepdims=True) return loss diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 220e981618b7c0bfb1e4e98c087d83b451b9b3cf..1ad40aca2880940c78d746674c7378ff0427c057 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -166,7 +166,7 @@ def per_example_squared_loss(labels, weights, predictions): update_op: An update operation to update the loss's internal state. 
""" unweighted_loss = math_ops.reduce_sum( - math_ops.square(predictions - labels), 1, keepdims=True) + math_ops.squared_difference(predictions, labels), 1, keepdims=True) return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 94aeb2c7bb48c6eddb6c7894f8bf6f1567470113..0fe57c0a4e8375cc7ec7aca9553bded87e238b33 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -34,7 +34,7 @@ class DecisionTreeEnsembleResource : public StampedResource { protobuf::Arena::CreateMessage< boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {} - string DebugString() override { + string DebugString() const override { return strings::StrCat("GTFlowDecisionTreeEnsemble[size=", decision_tree_ensemble_->trees_size(), "]"); } diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h index fdaaae7f472c8f564ab45a8366d3746cbf1158ee..574e3065e7f46049815897ef73e44d33f0d23f0f 100644 --- a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h @@ -43,7 +43,7 @@ class QuantileStreamResource : public StampedResource { set_stamp(stamp_token); } - string DebugString() override { return "QuantileStreamResource"; } + string DebugString() const override { return "QuantileStreamResource"; } tensorflow::mutex* mutex() { return &mu_; } diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 94b7f4f867655bf7fdf94e8488eeae7088c41622..7b3df962542a656af8052e9f2eae6e83744411f2 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py 
@@ -27,7 +27,7 @@ Managing dependencies: @@NoDependency @@split_dependency -Checkpointable data structures: +Trackable data structures: @@List @@Mapping @@UniqueNameTracker @@ -49,17 +49,16 @@ from tensorflow.contrib.checkpoint.python.python_state import NumpyState from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint -from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.core.protobuf.trackable_object_graph_pb2 import TrackableObjectGraph as CheckpointableObjectGraph from tensorflow.python.training.checkpoint_management import CheckpointManager -from tensorflow.python.training.checkpointable.base import CheckpointableBase -from tensorflow.python.training.checkpointable.data_structures import List -from tensorflow.python.training.checkpointable.data_structures import Mapping -from tensorflow.python.training.checkpointable.data_structures import NoDependency -from tensorflow.python.training.checkpointable.tracking import Checkpointable -from tensorflow.python.training.checkpointable.util import capture_dependencies -from tensorflow.python.training.checkpointable.util import list_objects -from tensorflow.python.training.checkpointable.util import object_metadata - +from tensorflow.python.training.tracking.base import Trackable as CheckpointableBase +from tensorflow.python.training.tracking.data_structures import List +from tensorflow.python.training.tracking.data_structures import Mapping +from tensorflow.python.training.tracking.data_structures import NoDependency +from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable +from tensorflow.python.training.tracking.util import capture_dependencies +from tensorflow.python.training.tracking.util import list_objects +from 
tensorflow.python.training.tracking.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index ada41687261ab63286933d01da4e286173042e0c..cd9c94c9bd72d398d183d3f3d485ab48cb2fd617 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -2,7 +2,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "checkpoint", @@ -12,7 +12,7 @@ py_library( ":python_state", ":split_dependency", ":visualize", - "//tensorflow/python/training/checkpointable:data_structures", + "//tensorflow/python/training/tracking:data_structures", ], ) @@ -22,22 +22,22 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:data_structures", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:data_structures", ], ) -py_test( +tf_py_test( name = "containers_test", srcs = ["containers_test.py"], - deps = [ + additional_deps = [ ":containers", + "@six_archive//:six", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", - "@six_archive//:six", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -47,24 +47,24 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", "//third_party/py/numpy", 
"@six_archive//:six", ], ) -py_test( +tf_py_test( name = "python_state_test", srcs = ["python_state_test.py"], - deps = [ + additional_deps = [ ":python_state", + "//third_party/py/numpy", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:session", "//tensorflow/python:variables", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:util", - "//third_party/py/numpy", + "//tensorflow/python/training/tracking:util", ], ) @@ -76,21 +76,21 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", ], ) -py_test( +tf_py_test( name = "split_dependency_test", srcs = ["split_dependency_test.py"], - deps = [ + additional_deps = [ ":split_dependency", "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -101,15 +101,15 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) -py_test( +tf_py_test( name = "visualize_test", srcs = ["visualize_test.py"], - deps = [ + additional_deps = [ ":visualize", "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", @@ -118,6 +118,7 @@ py_test( "//tensorflow/python/eager:test", "//tensorflow/python/keras:engine", "//tensorflow/python/keras:layers", - "//tensorflow/python/training/checkpointable:util", + 
"//tensorflow/python/training/tracking:util", ], + tags = ["no_oss"], # b/124472244 ) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 5418e2605b724edb60878e250d2c50fcc6ff5633..a25d51980ea760dfb7f323497a397fbd94fd5f23 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -1,4 +1,4 @@ -"""Checkpointable data structures.""" +"""Trackable data structures.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.training.checkpointable import base as checkpointable_lib -from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.tracking import base as trackable_lib +from tensorflow.python.training.tracking import data_structures -class UniqueNameTracker(data_structures.CheckpointableDataStructure): - """Adds dependencies on checkpointable objects with name hints. +class UniqueNameTracker(data_structures.TrackableDataStructure): + """Adds dependencies on trackable objects with name hints. Useful for creating dependencies with locally unique names. @@ -43,30 +43,30 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): def __init__(self): super(UniqueNameTracker, self).__init__() - self._maybe_initialize_checkpointable() + self._maybe_initialize_trackable() self._name_counts = {} @property def _values(self): return [dep.ref for dep in self._checkpoint_dependencies] - def track(self, checkpointable, base_name): - """Add a dependency on `checkpointable`. + def track(self, trackable, base_name): + """Add a dependency on `trackable`. Args: - checkpointable: An object to add a checkpoint dependency on. 
+ trackable: An object to add a checkpoint dependency on. base_name: A name hint, which is uniquified to determine the dependency name. Returns: - `checkpointable`, for chaining. + `trackable`, for chaining. Raises: - ValueError: If `checkpointable` is not a checkpointable object. + ValueError: If `trackable` is not a trackable object. """ - if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + if not isinstance(trackable, trackable_lib.Trackable): raise ValueError( - ("Expected a checkpointable value, got %s which does not inherit " - "from CheckpointableBase.") % (checkpointable,)) + ("Expected a trackable value, got %s which does not inherit " + "from tf.track.Trackable.") % (trackable,)) def _format_name(prefix, number): if number > 0: @@ -80,5 +80,5 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): count += 1 candidate = _format_name(base_name, count) self._name_counts[base_name] = count + 1 - self._track_value(checkpointable, name=candidate) - return checkpointable + self._track_value(trackable, name=candidate) + return trackable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index ac85c7be803cd4c2f8ba19d3ef887a3c65a15933..bace21939602666aa48a05d2abfe05ae6aae41e2 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,9 +26,9 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import data_structures -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import data_structures +from tensorflow.python.training.tracking import tracking +from 
tensorflow.python.training.tracking import util class UniqueNameTrackerTests(test.TestCase): @@ -52,7 +52,7 @@ class UniqueNameTrackerTests(test.TestCase): save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = tracking.Checkpointable() + restore_slots = tracking.AutoTrackable() restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) @@ -68,7 +68,7 @@ class UniqueNameTrackerTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(tracking.Checkpointable): + class SlotManager(tracking.AutoTrackable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 302d5cfb79a08b6adf52ebd44533152c5454eadc..737a6c30c1dce65dd7638ee52e6c26a8a40f8321 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -23,7 +23,7 @@ import six import numpy -from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.tracking import base # pylint: disable=g-import-not-at-top try: @@ -34,8 +34,8 @@ except ImportError: # pylint: enable=g-import-not-at-top -class NumpyState(base.CheckpointableBase): - """A checkpointable object whose NumPy array attributes are saved/restored. +class NumpyState(base.Trackable): + """A trackable object whose NumPy array attributes are saved/restored. Example usage: @@ -72,7 +72,7 @@ class NumpyState(base.CheckpointableBase): """Create placeholder NumPy arrays for to-be-restored attributes. Typically `_lookup_dependency` is used to check by name whether a dependency - exists. We cheat slightly by creating a checkpointable object for `name` if + exists. 
We cheat slightly by creating a trackable object for `name` if we don't already have one, giving us attribute re-creation behavior when loading a checkpoint. @@ -85,7 +85,7 @@ class NumpyState(base.CheckpointableBase): value = super(NumpyState, self)._lookup_dependency(name) if value is None: value = _NumpyWrapper(numpy.array([])) - new_reference = base.CheckpointableReference(name=name, ref=value) + new_reference = base.TrackableReference(name=name, ref=value) self._unconditional_checkpoint_dependencies.append(new_reference) self._unconditional_dependency_names[name] = value super(NumpyState, self).__setattr__(name, value) @@ -101,7 +101,7 @@ class NumpyState(base.CheckpointableBase): def __setattr__(self, name, value): """Automatically wrap NumPy arrays assigned to attributes.""" # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making - # ndarrays checkpointable natively and using standard checkpointable list + # ndarrays trackable natively and using standard trackable list # tracking. if isinstance(value, (numpy.ndarray, numpy.generic)): try: @@ -110,19 +110,19 @@ class NumpyState(base.CheckpointableBase): return except AttributeError: value = _NumpyWrapper(value) - self._track_checkpointable(value, name=name, overwrite=True) + self._track_trackable(value, name=name, overwrite=True) elif (name not in ("_setattr_tracking", "_update_uid") and getattr(self, "_setattr_tracking", True)): - # Mixing restore()-created attributes with user-added checkpointable + # Mixing restore()-created attributes with user-added trackable # objects is tricky, since we can't use the `_lookup_dependency` trick to # re-create attributes (we might accidentally steal the restoration for - # another checkpointable object). For now `NumpyState` objects must be + # another trackable object). For now `NumpyState` objects must be # leaf nodes. 
Theoretically we could add some extra arguments to # `_lookup_dependency` to figure out whether we should create a NumPy # array for the attribute or not. raise NotImplementedError( ("Assigned %s to the %s property of %s, which is not a NumPy array. " - "Currently mixing NumPy arrays and other checkpointable objects is " + "Currently mixing NumPy arrays and other trackable objects is " "not supported. File a feature request if this limitation bothers " "you.") % (value, name, self)) @@ -130,7 +130,7 @@ class NumpyState(base.CheckpointableBase): @six.add_metaclass(abc.ABCMeta) -class PythonStateWrapper(base.CheckpointableBase): +class PythonStateWrapper(base.Trackable): """Wraps a Python object for storage in an object-based checkpoint.""" @abc.abstractmethod diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 45494351ff4e6c8c75634d8563c3fb63c6089036..40d8fe836402c8b6c8240ef9f665b753c54ede0d 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -26,7 +26,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import variables -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import util class NumpyStateTests(test.TestCase): diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 7e77453f3d848c2e321ed2ba66917a742d95459a..d7b02b538909305b14e638761bd8ba67a948d2b4 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -21,7 +21,7 @@ import functools from tensorflow.python.ops import control_flow_ops from tensorflow.python.training import saver as saver_lib -from 
tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): @@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return self._restore_callback(tensor) -class _SplitDependency(checkpointable.CheckpointableBase): +class _SplitDependency(trackable.Trackable): """Looks like a regular variable while synchronizing save/restores.""" def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, @@ -81,9 +81,9 @@ class _SplitDependency(checkpointable.CheckpointableBase): return control_flow_ops.no_op() def _gather_saveables_for_checkpoint(self): - """Looks to Checkpointable like a regular variable.""" + """Looks to Trackable like a regular variable.""" return { - checkpointable.VARIABLE_VALUE_KEY: + trackable.VARIABLE_VALUE_KEY: functools.partial(_CallbackSaveable, dtype=self._dtype, save_callback=self._save, @@ -117,7 +117,7 @@ def split_dependency(component_names, component_dtypes, may return `None`). Returns: - A dictionary mapping from names to Checkpointable objects. If one is + A dictionary mapping from names to Trackable objects. If one is reachable from an object as a dependency, the others should be too; adding dependencies on some but not all of the objects will result in errors. 
""" diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 00a805af25d5d0ea723db5d015fb12bf45c53857..9bc01059481ff69064e3f9c682a764146b79a250 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,9 +23,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import base +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util def _split_variable_closure(variable): @@ -44,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(base.CheckpointableBase): +class SaveTensorSlicesAsDeps(base.Trackable): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -56,17 +56,17 @@ class SaveTensorSlicesAsDeps(base.CheckpointableBase): consume_restore_buffer_fn=_combine_variable_closure( self.combined)) for name, dep in split_dependencies.items(): - self._track_checkpointable(dep, name=name) + self._track_trackable(dep, name=name) -class HasRegularDeps(tracking.Checkpointable): +class HasRegularDeps(tracking.AutoTrackable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(tracking.Checkpointable): +class OnlyOneDep(tracking.AutoTrackable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) diff --git 
a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index bac071c4cff383f60b707b6e42c13faf5e0ac948..faf90f018476b3c70a7bfa1346a5b590edbbddcd 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.training.tracking import util as trackable_utils def dot_graph_from_checkpoint(save_path): @@ -51,7 +51,7 @@ def dot_graph_from_checkpoint(save_path): A graph in DOT format as a string. """ reader = pywrap_tensorflow.NewCheckpointReader(save_path) - object_graph = checkpointable_utils.object_metadata(save_path) + object_graph = trackable_utils.object_metadata(save_path) shape_map = reader.get_variable_to_shape_map() dtype_map = reader.get_variable_to_dtype_map() graph = 'digraph {\n' @@ -63,7 +63,7 @@ def dot_graph_from_checkpoint(save_path): slot_ids.add(slot_reference.slot_variable_node_id) for node_id, node in enumerate(object_graph.nodes): if (len(node.attributes) == 1 - and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY): + and node.attributes[0].name == trackable.VARIABLE_VALUE_KEY): if node_id in slot_ids: color = 'orange' tooltip_prefix = 'Slot variable' diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py index 583e3bc442893d825c337d73fb999d1e586738a1..98a22d573fdb6172cde100df461d9ae520c2c483 100644 --- a/tensorflow/contrib/checkpoint/python/visualize_test.py +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -28,7 +28,7 @@ from tensorflow.python.keras.engine import 
training from tensorflow.python.keras.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import adam -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils try: import pydot # pylint: disable=g-import-not-at-top @@ -57,7 +57,7 @@ class DotGraphTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = resource_variable_ops.ResourceVariable(12) - save_checkpoint = checkpointable_utils.Checkpoint( + save_checkpoint = trackable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) optimizer.minimize(functools.partial(model, input_value)) checkpoint_directory = self.get_temp_dir() diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 1311063ec023bdaa2588d6f1c826bf900f7dea09..20f8c2b2453a58fdbe5a3587fa6687debd9c06d3 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -27,7 +27,6 @@ tf_kernel_library( deps = [ ":bigquery_table_accessor", ":bigquery_table_partition_proto_cc", - "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:reader_base", @@ -79,7 +78,6 @@ tf_kernel_library( srcs = ["gcs_config_ops.cc"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform/cloud:curl_http_request", diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py index 390b3e7550b3d991269bb84707c3500f2fa33290..a4dea85efd98893c881abbd3f7ebda78755b8189 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -23,7 +23,7 @@ from __future__ import print_function 
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver @@ -36,7 +36,7 @@ _allowed_symbols = [ 'ClusterResolver', 'SimpleClusterResolver', 'UnionClusterResolver', - 'GceClusterResolver', + 'GCEClusterResolver', 'KubernetesClusterResolver', 'TFConfigClusterResolver', 'TPUClusterResolver', diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py index 10d93549ebbd4f7e900796d0516b0af1744224af..ef1e9f11a07a5be6c0b181f5e0b80e0e2214f972 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -25,7 +25,7 @@ from __future__ import print_function from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver from 
tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver @@ -43,7 +43,7 @@ _allowed_symbols = [ 'ClusterResolver', 'SimpleClusterResolver', 'UnionClusterResolver', - 'GceClusterResolver', + 'GCEClusterResolver', 'KubernetesClusterResolver', 'TFConfigClusterResolver', 'TPUClusterResolver', diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 55e61155c683c928efab9bb018868faec3e3df8c..5b49116ff6a4e17a774ea79b33ae1b948ba9f187 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Stub file for GceClusterResolver to maintain backwards compatibility.""" +"""Stub file for GCEClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division @@ -23,13 +23,14 @@ from __future__ import print_function # existing OSS code will not be broken. 
# pylint: disable=unused-import -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented + _allowed_symbols = [ - 'GceClusterResolver', + 'GCEClusterResolver', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index df8b48dfc46124d3b9454d92ffb70dbcf1bc4217..60ee1b4b3fd7d0b6afaefcc05effd3bbae00cf2c 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -147,19 +147,19 @@ suitable interface for project configuration and dependency setting. * Go (required if you need ssl support, optional) * NASM/YASM (required by grpc for ssl support, optional) 2. Start CMake GUI -3. Click on `Browse Source` and direct to the the folder +3. Click on `Browse Source` and direct to the folder `/tensorflow/contrib/cmake` 4. Click on `Browse Build` and spectify a location that you want tensorflow to be build 5. Click on `Configure`, a new window will be prompted out, specify the generator mode for the project generation. For Windows, choose `Visual Studio Win64`, for Linux, choose `Unix Makefiles`, then - press `Finish`. Wait for a moment, the default project dependecy would + press `Finish`. Wait for a moment, the default project dependency would automatically generate. 6. There are a few options that you can customize your own build. **The setting - here is crucial for a sucessful build, please check all items carefully.** + here is crucial for a successful build, please check all items carefully.** - * `tensorflow_BUILD_ALL_KERNELS` should alway be `on` + * `tensorflow_BUILD_ALL_KERNELS` should always be `on` * `tensorflow_BUILD_CC_EXAMPLE` is default to be `on`. 
This can help you to test build (optional) * `tensorflow_BUILD_CONTRIB_KERNELS` is default to be `on`, but it won't @@ -278,7 +278,7 @@ suitable interface for project configuration and dependency setting. `make -sj install` Where `` is the threads used for the compilation, change - to any integer less or equal to your computer's maxiumum thread number. + to any integer less or equal to your computer's maximum thread number. Headers are discretely located in the build folders. Tensorflow library can be found at ``, namely `tensorflow.so` (Linux) or diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index 46a193971c5084523d432065f265fa7a9909f595..6c6a5df7f76723800740a81ccdcb137a0ec33846 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -31,17 +31,17 @@ if (systemlib_ABSEIL_CPP) message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") - add_custom_target(abseil_cpp) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) + add_custom_target(abseil_cpp_build) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) else (systemlib_ABSEIL_CPP) include (ExternalProject) - set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp) - set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) - set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) - set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp-build) + set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_URL https://github.com/abseil/abseil-cpp.git) + set(abseil_cpp_TAG master) + set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -49,8 +49,11 @@ 
else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/time/Release/absl_time.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) else() set(abseil_cpp_STATIC_LIBRARIES @@ -62,6 +65,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/numeric/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/time/absl_time.lib ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) endif() else() @@ -74,15 +78,18 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/numeric/libabsl_int128.a ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a + ${abseil_cpp_BUILD}/absl/time/libabsl_time.a ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) endif() - ExternalProject_Add(abseil_cpp + ExternalProject_Add(abseil_cpp_build PREFIX abseil_cpp - URL ${abseil_cpp_URL} - URL_HASH ${abseil_cpp_HASH} + GIT_REPOSITORY ${abseil_cpp_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release + COMMAND ${CMAKE_COMMAND} --build . 
--config Release INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} @@ -91,8 +98,10 @@ else (systemlib_ABSEIL_CPP) ) include_directories(${abseil_cpp_INCLUDE_DIR}) + message(STATUS ${abseil_cpp_INCLUDE_DIR}) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index e570c09ecb5e64130ed6f3375a51d74850cc3989..30b4e2dbdee1117df12ae7ab8ce902e667234fb0 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 69b6c047bc767b4d80e7af4d00ccb7c45b683dae) +set(GRPC_TAG 62688b6a05cc85b47fb77dd408611734253e47e2) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 479609458c64f7c7bd7b3ce6b23aceaa3db17f21..b15143bfc1cd787b156c9d6dd724a17730f0f8fb 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 1.20.1) +set(nsync_TAG 1.20.2) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 
96160568fa79291a7b391761373e1eaf0f70974e..fd205a4b9b065a4756fbe3985694bb64b93b85e6 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -1,6 +1,9 @@ # python_sanity_test.py will complain about invalid or missing entries # problematic entries can be commented for temporary whitelisting tensorflow +tensorflow/compiler +tensorflow/compiler/xla +tensorflow/compiler/xla/service tensorflow/core tensorflow/core/example tensorflow/core/framework @@ -10,6 +13,7 @@ tensorflow/core/lib tensorflow/core/lib/core tensorflow/core/profiler tensorflow/core/protobuf +tensorflow/core/protobuf/tpu tensorflow/core/util tensorflow/examples tensorflow/examples/tutorials @@ -67,8 +71,9 @@ tensorflow/python/summary/writer tensorflow/python/tools tensorflow/python/tools/api tensorflow/python/tools/api/generator +tensorflow/python/tpu tensorflow/python/training -tensorflow/python/training/checkpointable +tensorflow/python/training/tracking tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf @@ -434,7 +439,6 @@ tensorflow/contrib/timeseries/python/timeseries/state_space_models tensorflow/contrib/tpu tensorflow/contrib/tpu/ops tensorflow/contrib/tpu/profiler -tensorflow/contrib/tpu/proto tensorflow/contrib/tpu/python tensorflow/contrib/tpu/python/ops tensorflow/contrib/tpu/python/profiler diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index 013180c89083748b240ad061b342300e886d3568..b4603206da419f44af0857b9b933eb7df1b255ff 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -1,6 +1,7 @@ tensorflow/core tensorflow/core/kernels/boosted_trees tensorflow/core/profiler +tensorflow/core/protobuf/tpu tensorflow/python tensorflow/contrib/boosted_trees/proto tensorflow/contrib/cloud/kernels @@ -12,7 +13,6 @@ tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle 
tensorflow/contrib/tensor_forest/proto tensorflow/contrib/tensorboard/plugins/projector -tensorflow/contrib/tpu/proto tensorflow/contrib/tpu/profiler tensorflow/contrib/training/python/training tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index d7b2a1339e047aba0a9424a53a63726805e89721..cc263d7995c01100f1c51436bcb584b600c8c161 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -125,9 +125,9 @@ endfunction() file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" + "${tensorflow_source_dir}/tensorflow/core/protobuf/tpu/*.proto" "${tensorflow_source_dir}/tensorflow/compiler/xla/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" ) RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS @@ -147,7 +147,6 @@ set(tf_proto_text_srcs "tensorflow/core/framework/function.proto" "tensorflow/core/framework/graph.proto" "tensorflow/core/framework/graph_transfer_info.proto" - "tensorflow/core/framework/iterator.proto" "tensorflow/core/framework/kernel_def.proto" "tensorflow/core/framework/log_memory.proto" "tensorflow/core/framework/node_def.proto" @@ -302,8 +301,8 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h" + "${tensorflow_source_dir}/tensorflow/core/summary/*.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/*.h" "${tensorflow_source_dir}/public/*.h" ) @@ -317,14 +316,14 @@ file(GLOB_RECURSE 
tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.h" "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/loader.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/vacuum.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/*test*.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/loader.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/vacuum.cc" ) # TODO(jart): Why doesn't this work? # set_source_files_properties( -# ${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/snapfn.cc +# ${tensorflow_source_dir}/tensorflow/core/lib/db/snapfn.cc # PROPERTIES COMPILE_FLAGS -DSQLITE_OMIT_LOAD_EXTENSION) list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8faccf8d55902e6701ebb4ce534b84705304fd5f..1fe8795ddf00232eba5a60a130e0845a6f6a8e17 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -802,6 +802,7 @@ add_custom_command( # tensorflow/__init__.py depends on files generated in this step. So, remove it while # this step is running since the files aren't there yet. COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. 
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index e4566437c60ebb2da039e61c171fbe954a7355c9..79c61589112b739837b401010690e7f4ca917d07 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -23,6 +23,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":xla", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", ], @@ -70,22 +71,30 @@ py_library( ], ) -tf_py_test( +cuda_py_test( name = "xla_test", srcs = ["xla_test.py"], additional_deps = [ ":xla", - "@six_archive//:six", + "@absl_py//absl/testing:parameterized", + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_util", "//tensorflow/python:math_ops", "//tensorflow/python:platform", - "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + ], + tags = [ + "no_mac", + "no_windows", ], - tags = ["no_pip"], + xla_enabled = True, ) diff --git a/tensorflow/contrib/compiler/__init__.py b/tensorflow/contrib/compiler/__init__.py index c4937dadfb8be3211377f0ae7017b95e7642dab0..797e5e8164e231e8b3806d40b32774711879b050 100644 --- a/tensorflow/contrib/compiler/__init__.py +++ b/tensorflow/contrib/compiler/__init__.py @@ -19,3 +19,4 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.compiler import jit +from tensorflow.contrib.compiler import xla diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 
f867cd15b67dbd43650d8012b4299845af7200a8..238c6ab1366a50710efabea2f33eb1bd06fe9423 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import function_utils +from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -76,10 +77,22 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin All `Operation`s returned from `computation` will be executed when evaluating any of the returned output tensors. - inputs: A list of input tensors or `None` (equivalent to an empty list). + inputs: A list of inputs or `None` (equivalent to an empty list). Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimension list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. Returns: - A list of output tensors. + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: + 1) None output: a NoOp would be returned which control-depends on + computation. + 2) Single value output: A tuple containing the value would be returned. + 3) Operation-only outputs: a NoOp would be returned which + control-depends on computation. + TODO(b/121383831): Investigate into removing these special cases. """ # pylint: disable=protected-access return _compile_internal(computation, inputs) @@ -131,6 +144,30 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): logging.warning('... 
and %d more', len(self._unsupported_ops) - _MAX_WARNING_LINES) + def _RemoveExternalControlEdges(self, op): + """Remove any external control dependency on this op.""" + internal_control_inputs = [] + external_control_inputs = [] + for x in op.control_inputs: + # pylint: disable=protected-access + is_internal_op = False + ctxt = x._get_control_flow_context() + while ctxt is not None: + if ctxt == self: + is_internal_op = True + break + ctxt = ctxt._outer_context + if is_internal_op: + internal_control_inputs.append(x) + else: + external_control_inputs.append(x) + # pylint: enable=protected-access + # pylint: disable=protected-access + op._remove_all_control_inputs() + op._add_control_inputs(internal_control_inputs) + # pylint: enable=protected-access + return internal_control_inputs, external_control_inputs + def AddOp(self, op): """Create op in XLACompileContext and notifies outer context recursively.""" # pylint: disable=protected-access @@ -180,11 +217,14 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] + with ops.control_dependencies(None): + self.Enter() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] + self.Exit() # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access @@ -245,13 +285,21 @@ def _compile_internal(computation, inputs=None): Args: computation: A Python function that builds the computation to compile and execute. - inputs: A list of input tensors or `None` (equivalent to `[]`). Its order - should match ordering of computation arguments. + inputs: A list of inputs or `None` (equivalent to an empty list). 
Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimension list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. + Returns: - A list of output tensors from computation. + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: 1) None output 2) Single + value output 3) Operation-only outputs Raises: ValueError: If any element in computation outputs is neither an operations or a value that can be converted to tensor. + ValueError: If computation outputs is non-flat and contains any Operations. TypeError: If `inputs` is not a list or tuple. """ if inputs is None: @@ -260,17 +308,10 @@ def _compile_internal(computation, inputs=None): if not isinstance(inputs, collections.Sequence): raise TypeError('inputs must be a list') + # Flatten inputs. + flat_inputs = nest.flatten(inputs) # Converts inputs to Tensors. - inputs = [ops.convert_to_tensor(x) for x in inputs] - input_arity = len(inputs) - - arg_error = check_function_argument_count( - computation, input_arity, infeed_queue=None) - if arg_error is not None: - raise TypeError( - 'Supplied computation cannot be called with the specified inputs. You ' - 'specified %d inputs: %s, but the computation needs %s' % - (input_arity, str([i.name for i in inputs]), arg_error)) + flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] cluster_name = ops.get_default_graph().unique_name('cluster') pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') @@ -280,11 +321,15 @@ def _compile_internal(computation, inputs=None): # Add identity ops so even unused inputs are 'consumed' by the # computation. 
- computation_inputs = [ + flat_inputs = [ array_ops.identity(x, name='input_{}'.format(i)) - for i, x in enumerate(inputs) + for i, x in enumerate(flat_inputs) ] + # Re-pack flat_inputs in same structure as 'inputs'. + computation_inputs = nest.pack_sequence_as( + structure=inputs, flat_sequence=flat_inputs) + # Only resource variables work inside an XLA computation, so turn on # resource variables for the computation. vscope = variable_scope.get_variable_scope() @@ -297,66 +342,166 @@ def _compile_internal(computation, inputs=None): # Restore variable scope after computation. vscope.set_use_resource(saved_use_resource) - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, make it a tuple. - if not isinstance(outputs, collections.Sequence): - outputs = (outputs,) - - # Append `no_op` here so that return value of this function always contains - # at least one op that can trigger XlaLaunch node. - outputs += (control_flow_ops.no_op(),) - try: - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - 'XLA computation function return values must all either be Operations' - ' or convertible to Tensors. Got error: "%s"' % str(e)) - - # Separates the returned Operations and Tensors. 
- output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - 'XLA computation function must return zero or more Tensor values ' - 'followed by zero or more Operations.') - output_arity = len(output_tensors) - - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else ''): - new_output_tensors.append(array_ops.identity(t)) + outputs_is_flat = is_flat(outputs) + if outputs_is_flat: + output_tensors, control_deps = _postprocess_flat_outputs(outputs) + else: + output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() - outputs = [ - xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i)) - for i in xrange(output_arity) + # When XLA computation returns only operations and no tensors, a NoOp + # dependent on the operations in outputs is returned. Otherwise final + # outputs would be empty and there is no way to trigger returned + # operations. + if not output_tensors: + return control_flow_ops.group(control_deps, name='output_0') + + output_tensors = [ + xla_ops.xla_cluster_output(o, name='output{}'.format(i)) + for i, o in enumerate(output_tensors) ] - with ops.control_dependencies(output_operations): - if output_arity == 0: - # When XLA computation returns only operations and no tensors, a NoOp - # dependent on the operations in outputs is returned. Otherwise final - # outputs would be empty and there is no way to trigger returned - # operations. - return control_flow_ops.no_op(name='output_0') - else: - # Wraps the outputs in identity operators that carries control - # dependencies. 
- return [ - array_ops.identity(outputs[i], name='output_%d' % i) - for i in xrange(output_arity) - ] + with ops.control_dependencies(control_deps): + # Wraps the outputs in identity operators that carries control + # dependencies. + output_tensors = [ + array_ops.identity(o, name='output_%d' % i) + for i, o in enumerate(output_tensors) + ] + + # If `computation` returned non-flat output structure, pack output tensors + # back into same structure. + if not outputs_is_flat: + output_tensors = nest.pack_sequence_as( + structure=outputs, flat_sequence=output_tensors) + + return output_tensors + + +def is_flat(outputs): + """Checks if outputs is a flat structure. + + Following structures and values are considered flat: + 1) None + 2) A single object + 3) A list or tuple of Tensors/Operations + + The only structures that this function understands are sequences and + dictionaries. E.g. this means that if outputs contains a single + user-defined Object, it is considered to be flat. Errors are raised later on + if that Object cannot be converted to a Tensor. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + A boolean indicates whether outputs is flat. + """ + # If outputs is a list or tuple, check if it has any nested structure. If + # there is, then outputs is non-flat. + if isinstance(outputs, collections.Sequence): + for o in outputs: + if isinstance(o, collections.Sequence) or isinstance(o, dict): + return False + + # If outputs is a dict, it is non-flat. + if isinstance(outputs, dict): + return False + + # Getting here means either outputs itself is a single non-structured value + # or it is a flat list of single non-structured values. + return True + + +def _postprocess_flat_outputs(outputs): + """Validates flat outputs and adds back device assignments. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + Tensors and Operations extracted from outputs. 
+ """ + # Following code segment is to preserve legacy behavior. Previously we only + # supported flat outputs and thus for consistency it was nice to convert even + # single element into a tuple. But now that we support arbitrary output + # structure, this is no longer necessary. + # TODO(b/121383831): Migrate all legacy use cases and delete this special + # case. + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() + # If the computation only returned one value, make it a tuple. + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + + # Append `no_op` here so that return value of this function always contains + # at least one op that can trigger XlaLaunch node. + outputs += (control_flow_ops.no_op(),) + try: + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs + ] + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be Operations' + ' or convertible to Tensors. Got error: "%s"' % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + 'XLA computation function must return zero or more Tensor values ' + 'followed by zero or more Operations.') + + new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else ''): + new_output_tensors.append(array_ops.identity(t)) + + return new_output_tensors, output_operations + + +def _postprocess_non_flat_outputs(outputs): + """Validates non-flat outputs and adds back device assignments. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + Tensors extracted from outputs and an empty list because Operations are not + allowed in non-flat outputs.. 
+ """ + # Convert all non-Operation outputs to Tensors. + new_output_tensors = [] + for o in nest.flatten(outputs): + if isinstance(o, ops.Operation): + raise ValueError( + 'xla.compile does not support Operation as return value in non-flat ' + 'output structure. You can set returned Operations as control ' + 'dependencies of returned Tensors so Operations are triggered when ' + 'Tensors are evaluated. Operation found: "%s"' % o.name) + + try: + o = ops.convert_to_tensor(o) + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be ' + 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) + + # Makes sure even pass-through inputs/outputs are touched in compile + # context by creating an Identity node inside compile context. + with ops.device(o.device if o.device else ''): + new_output_tensors.append(array_ops.identity(o)) + + return new_output_tensors, [] @contextlib.contextmanager diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index 3b49755afcf0753d31c0ce506dce42709b1ee8bc..c4384dcde75035dc55e67bd503e348fe19b97025 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -18,11 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re +from absl.testing import parameterized + from tensorflow.contrib.compiler import xla +from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.contrib.tpu.python.tpu import tpu_feed +from tensorflow.contrib.training.python.training import hparam from tensorflow.python import summary +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from 
tensorflow.python.ops import control_flow_util from tensorflow.python.ops import logging_ops @@ -30,6 +38,14 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +from tensorflow.python.training import training + + +_TRAIN = model_fn_lib.ModeKeys.TRAIN +_EVAL = model_fn_lib.ModeKeys.EVAL +_EXPECTED_LOSS = 1 +_EXPECTED_FEATURE = 2 +_EXPECTED_LABEL = 3 class XLACompileContextTest(test.TestCase): @@ -252,5 +268,329 @@ class CheckFunctionArgumentCountTest(test.TestCase): xla.check_function_argument_count(func, 0, queue)) +def _test_train_model_fn(features, labels, mode, params): + """A dummy model_fn for testing purpose.""" + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=loss, train_op=array_ops.identity(loss)) + + +@xla.estimator_model_fn +def decorated_model_fn(features, labels, mode, params): + return _test_train_model_fn(features, labels, mode, params) + + +def make_dummy_features_labels(): + # XLA CPU/GPU backend doesn't support guaranteed constant, thus use dataset + # container to work around. 
+ features_dataset = dataset_ops.Dataset.from_tensors( + constant_op.constant(_EXPECTED_FEATURE)).repeat(10) + features_op = features_dataset.make_one_shot_iterator().get_next() + labels_dataset = dataset_ops.Dataset.from_tensors( + constant_op.constant(_EXPECTED_LABEL)).repeat(10) + labels_op = labels_dataset.make_one_shot_iterator().get_next() + return features_op, labels_op + + +class XlaDecoratorTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('test_use_as_decorator', decorated_model_fn, None), + ('test_use_as_function', xla.estimator_model_fn(_test_train_model_fn), + None), + ('test_use_tpu_false_hparams', decorated_model_fn, + hparam.HParams(use_tpu=False)), + ('test_use_tpu_false_dict_params', decorated_model_fn, { + 'use_tpu': False + }), + ) + def test_compile(self, model_fn, params): + """Calls model_fn and verifies it is compiled.""" + with test.mock.patch.object(xla, 'compile') as mock_xla_compile: + loss = constant_op.constant(_EXPECTED_LOSS) + mock_xla_compile.return_value = [loss] + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn( + features=features, labels=labels, mode=_TRAIN, params=params or {}) + + self.assertEqual(mock_xla_compile.call_count, 1) + self.assertEqual(estimator_spec.mode, _TRAIN) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), sess.run(loss)) + self.assertEqual(sess.run(estimator_spec.train_op), sess.run(loss)) + + @parameterized.named_parameters( + ('test_use_tpu_true_hparams', decorated_model_fn, + hparam.HParams(use_tpu=True)), + ('test_use_tpu_true_dict_params', decorated_model_fn, { + 'use_tpu': True + }), + ) + def test_not_compile(self, model_fn, params): + """Calls model_fn and verifies it is NOT compiled.""" + with test.mock.patch.object(xla, 'compile') as mock_xla_compile: + loss = constant_op.constant(_EXPECTED_LOSS) + mock_xla_compile.return_value = [loss] + + features, labels = make_dummy_features_labels() + 
estimator_spec = model_fn( + features=features, labels=labels, mode=_TRAIN, params=params or {}) + + mock_xla_compile.assert_not_called() + self.assertEqual(estimator_spec.mode, _TRAIN) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), sess.run(loss)) + self.assertEqual(sess.run(estimator_spec.train_op), sess.run(loss)) + + def test_model_with_summary(self): + """Tests that summary ops are disabled.""" + + @xla.estimator_model_fn + def model_fn_with_summary(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + summary.scalar('loss_scalar_summary', loss) + summary.histogram('loss_histogram_summary', loss) + summary.image('loss_image_summary', loss) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=loss, train_op=array_ops.identity(loss)) + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn_with_summary( + features=features, labels=labels, mode=_TRAIN, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + +def _test_eval_metric_fn(eval_tensor_1, eval_tensor_2): + return { + 'metric_1': (eval_tensor_1, eval_tensor_1), + 'metric_2': (eval_tensor_2, eval_tensor_2), + } + + +class XlaDecoratorEvaluationTest(test.TestCase): + + def _verify_evaluation_result(self, eval_model_fn): + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn( + features=features, labels=labels, mode=_EVAL, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_1'][0]), + _EXPECTED_FEATURE + _EXPECTED_LABEL) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_1'][1]), + _EXPECTED_FEATURE + _EXPECTED_LABEL) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_2'][0]), + _EXPECTED_FEATURE - _EXPECTED_LABEL) + 
self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_2'][1]), + _EXPECTED_FEATURE - _EXPECTED_LABEL) + + def test_eval_base_estimator_spec_eval_metric_ops_disallowed(self): + + @xla.estimator_model_fn + def eval_model_fn_return_estimator_spec(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops={ + 'metric': (array_ops.identity(loss), control_flow_ops.no_op()) + }) + + with self.assertRaisesRegexp( + ValueError, 'EstimatorSpec.eval_metric_ops is not supported with XLA ' + 'compilation. Please use TPUEstimatorSpec.eval_metrics instead.'): + self._verify_evaluation_result(eval_model_fn_return_estimator_spec) + + def test_eval_base_estimator_spec_no_eval_metric_ops(self): + + @xla.estimator_model_fn + def eval_model_fn_no_eval_metric_ops(features, labels, mode, params): + del features, labels, params + return model_fn_lib.EstimatorSpec( + mode=mode, loss=constant_op.constant(_EXPECTED_LOSS)) + + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn_no_eval_metric_ops( + features=features, labels=labels, mode=_EVAL, params={}) + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + def test_eval_no_eval_metrics(self): + + @xla.estimator_model_fn + def eval_model_fn_no_eval_metrics(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, loss=constant_op.constant(_EXPECTED_LOSS)) + + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn_no_eval_metrics( + features=features, labels=labels, mode=_EVAL, params={}) + + self.assertEqual(estimator_spec.eval_metric_ops, {}) + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + def test_eval_fn_missing_input_tensor(self): + + @xla.estimator_model_fn + def 
eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + with self.assertRaisesRegexp( + ValueError, + re.escape("Arguments ['eval_tensor_2'] are needed by metric_fn (first " + 'element of TPUEstimatorSpec.eval_metrics) but they are not ' + 'provided by evaluation tensors (second element of ' + 'TPUEstimatorSpec.eval_metrics).')): + self._verify_evaluation_result(eval_model_fn) + + def test_eval_fn_extraneous_input_tensor(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + 'eval_tensor_2': features - labels, + 'extra_tensor': features * 2 - labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + with self.assertRaisesRegexp( + ValueError, + re.escape("Arguments ['extra_tensor'] are provided by evaluation " + 'tensors (second element of TPUEstimatorSpec.eval_metrics) ' + 'but they are not needed by metric_fn (first element of ' + 'TPUEstimatorSpec.eval_metrics).')): + self._verify_evaluation_result(eval_model_fn) + + def test_eval_tensors_as_list(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors = [features + labels, features - labels] + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, dummy_eval_metric_fn_tensors)) + + self._verify_evaluation_result(eval_model_fn) + + def test_eval_tensors_as_dict(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, 
params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + 'eval_tensor_2': features - labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + self._verify_evaluation_result(eval_model_fn) + + def test_model_with_summary(self): + """Tests that summary ops are disabled.""" + + @xla.estimator_model_fn + def model_fn_with_summary(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + summary.scalar('loss_scalar_summary', loss) + summary.histogram('loss_histogram_summary', loss) + summary.image('loss_image_summary', loss) + return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss) + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn_with_summary( + features=features, labels=labels, mode=_EVAL, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + +class XlaDecoratorScaffoldTest(test.TestCase, parameterized.TestCase): + + def _make_scaffold_fn(self, mode): + + def _scaffold_fn_on_cpu(): + scaffold = training.Scaffold() + self.assertNotIn(mode, self.is_scaffold_fn_called) + self.is_scaffold_fn_called[mode] = True + return scaffold + + return _scaffold_fn_on_cpu + + def test_scaffold_fn_return_none(self): + + @xla.estimator_model_fn + def model_fn(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + train_op=control_flow_ops.no_op(), + scaffold_fn=lambda: None) + + features, labels = make_dummy_features_labels() + with self.assertRaisesRegexp( + ValueError, + 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed'): + model_fn(features=features, labels=labels, mode=_TRAIN, params={}) + + @parameterized.named_parameters( + 
('train_mode', _TRAIN), + ('eval_mode', _EVAL), + # TODO(ycao): Add predict_mode test after PREDICT mode is implemented. + ) + def test_scaffold_fn_in_mode(self, mode): + + @xla.estimator_model_fn + def model_fn(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + train_op=control_flow_ops.no_op(), + scaffold_fn=self._make_scaffold_fn(mode)) + + features, labels = make_dummy_features_labels() + + self.is_scaffold_fn_called = {} + model_fn(features=features, labels=labels, mode=mode, params={}) + self.assertTrue(self.is_scaffold_fn_called[mode]) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md index cb1dd7d836ae11700b2ffaaff4fda5b7f943f87d..7ffb6894d37444fd78015b6c124c46f2855c1cde 100644 --- a/tensorflow/contrib/constrained_optimization/README.md +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -1,5 +1,10 @@ +**NOTE: As tensorflow.contrib is being +[deprecated](https://github.com/tensorflow/community/pull/18), TFCO is moving to +its own repository on +[github](https://github.com/google-research/tensorflow_constrained_optimization).** + # ConstrainedOptimization (TFCO) TFCO is a library for optimizing inequality-constrained problems in TensorFlow. 
diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py index a4c49d48bc5c763489215261a909573af0f19055..280e9acd88638a9385bfd9128ba6d3739879aab2 100644 --- a/tensorflow/contrib/constrained_optimization/python/candidates_test.py +++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py @@ -52,12 +52,12 @@ class CandidatesTest(test.TestCase): distribution = candidates.find_best_candidate_distribution( objective_vector, constraints_matrix) # Verify that the solution is a probability distribution. - self.assertTrue(np.all(distribution >= 0)) + self.assertTrue(np.all(distribution >= -1e-6)) self.assertAlmostEqual(np.sum(distribution), 1.0) # Verify that the solution satisfies the constraints. maximum_constraint_violation = np.amax( np.dot(constraints_matrix, distribution)) - self.assertLessEqual(maximum_constraint_violation, 0) + self.assertLessEqual(maximum_constraint_violation, 1e-6) # Verify that the solution matches that which we expect. 
expected_distribution = np.array([0.37872711, 0.62127289, 0, 0]) self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index a268415f0e65206294431a537be18cadbe1a1e84..f5219eb134d07c09b16a544f71d4c18986c19681 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -68,6 +68,7 @@ def RunLSTM(sess, batch_size, time, num_layers=1, + variable_seq_lengths=False, is_training=True, dropout=0., num_dirs=True, @@ -99,6 +100,13 @@ def RunLSTM(sess, num_units).astype(dtype.as_numpy_dtype), dtype=dtype) + if variable_seq_lengths: + lengths_v = np.random.randint(low=1, high=time + 1, size=batch_size) + lengths_v[0] = time # make sure the max sequence has 'time' elems + lengths = ops.convert_to_tensor(lengths_v.astype(np.int32)) + else: + lengths = None + initializer = init_ops.random_uniform_initializer( -0.01, 0.01, dtype=dtype, seed=19980904) @@ -115,6 +123,7 @@ def RunLSTM(sess, outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, inputs, + sequence_length=lengths, initial_state=rnn_cell_impl.LSTMStateTuple( h=initial_h_op, c=initial_c_op), dtype=dtype, @@ -133,6 +142,7 @@ def RunLSTM(sess, cu_initial_h_op, cu_initial_c_op, opaque_params, + sequence_lengths=lengths, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) @@ -325,12 +335,19 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtype, - rtol=2e-6, - atol=2e-6): + variable_seq_lengths, + rtol=3e-6, + atol=3e-6): with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( - sess, num_units, input_size, batch_size, time, num_layers) + sess, + num_units, + 
input_size, + batch_size, + time, + num_layers, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -341,20 +358,33 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_training(self, num_units, input_size, batch_size, time, num_layers): + def test_training(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") - self._test_training_helper(num_units, input_size, batch_size, time, - num_layers, dtypes.float32) + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float32, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -365,12 +395,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtypes.float16, rtol=5e-3, - atol=5e-4) + atol=5e-4, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + 
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_inference(self, num_units, input_size, batch_size, time, num_layers): + def test_inference(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -381,7 +416,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - is_training=False) + is_training=False, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs) # h @@ -389,11 +425,14 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): # c self.assertAllClose(state_tuple.c, cu_state_tuple.c) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -405,7 +444,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dtype=dtypes.float16) + dtype=dtypes.float16, + variable_seq_lengths=variable_seq_lengths) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -416,11 +456,14 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose( state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + 
"variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): """Validates that dropout does not affect Cudnn Rnn inference.""" if not context.context().num_gpus(): self.skipTest("No GPUs found") @@ -436,7 +479,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=0.) + dropout=0., + variable_seq_lengths=variable_seq_lengths) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -448,7 +492,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=1.) + dropout=1., + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(cu_outputs, cu_outputs2) # h @@ -464,6 +509,7 @@ def RunGRU(sess, time, num_layers=1, is_training=True, + variable_seq_lengths=False, dropout=0., num_dirs=True, dtype=dtypes.float32): @@ -489,6 +535,13 @@ def RunGRU(sess, num_units).astype(dtype.as_numpy_dtype), dtype=dtype) + if variable_seq_lengths: + lengths_v = np.random.randint(low=1, high=time + 1, size=batch_size) + lengths_v[0] = time # make sure the max sequence has 'time' elems + lengths = ops.convert_to_tensor(lengths_v.astype(np.int32)) + else: + lengths = None + initializer = init_ops.random_uniform_initializer( -0.01, 0.01, dtype=dtype, seed=19980904) with variable_scope.variable_scope("test", initializer=initializer): @@ -521,6 +574,7 @@ def RunGRU(sess, outputs_op, h_op = rnn.dynamic_rnn( cell, inputs, + sequence_length=lengths, initial_state=initial_h_op, dtype=dtype, time_major=True, @@ -533,12 +587,14 @@ def RunGRU(sess, num_layers, num_units, input_size) opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) cu_outputs_op, cu_h_op, _ = 
cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, array_ops.zeros_like(cu_initial_h_op), # not used opaque_params, + sequence_lengths=lengths, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_GRU) @@ -615,12 +671,19 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtype, - rtol=2e-6, - atol=2e-6): + variable_seq_lengths, + rtol=3e-6, + atol=3e-6): with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, - cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( - sess, num_units, input_size, batch_size, time, num_layers) + (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, cu_hgrad, + wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @@ -631,20 +694,33 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): for wg, cu_wg in zip(wgrad, cu_wgrad): self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_training(self, num_units, input_size, batch_size, time, num_layers): + def test_training(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") - self._test_training_helper(num_units, input_size, batch_size, time, - num_layers, dtypes.float32) + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float32, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) 
+ @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -655,12 +731,17 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtypes.float16, rtol=5e-3, - atol=5e-4) + atol=5e-4, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_inference(self, num_units, input_size, batch_size, time, num_layers): + def test_inference(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -671,15 +752,19 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - is_training=False) + is_training=False, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs) self.assertAllClose(h, cu_h) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ 
-691,17 +776,21 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dtype=dtypes.float16) + dtype=dtypes.float16, + variable_seq_lengths=variable_seq_lengths) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): """Validates that dropout does not affect Cudnn Rnn inference.""" # Hand-picked dropouts are used below (0. and 1.) if not context.context().num_gpus(): @@ -717,7 +806,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=0.) + dropout=0., + variable_seq_lengths=variable_seq_lengths) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -729,7 +819,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=1.) 
+ dropout=1., + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(cu_outputs, cu_outputs2) self.assertAllClose(cu_h[0], cu_h2[0]) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 7e1b4062ce435f3ab4216e90b4f5fcbab984c1dc..403f30909520dc5cd5f5919af843291fe1400b91 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -58,7 +58,7 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -709,7 +709,7 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): self._TestSaveRestoreHelper(CUDNN_RNN_RELU) -class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): +class CudnnRNNTestSaveRestoreTrackable(test_util.TensorFlowTestCase): def _VerifyCheckpoint( self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn, @@ -718,7 +718,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") with ops.device("gpu:0"): cudnn_layer = cudnn_cell_fn() - cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer) + cudnn_checkpoint = trackable_utils.Checkpoint(cell=cudnn_layer) status = cudnn_checkpoint.restore(checkpoint_path) inputs = 3. 
* array_ops.ones([num_applications, num_layers, input_size], dtype=dtypes.float32) @@ -726,7 +726,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): status.run_restore_ops() second_save_path = cudnn_checkpoint.save(checkpoint_prefix) restore_layer = compatible_cell_fn() - restore_layer_checkpoint = checkpointable_utils.Checkpoint( + restore_layer_checkpoint = trackable_utils.Checkpoint( cell=restore_layer) status = restore_layer_checkpoint.restore(second_save_path) current_state = restore_layer.zero_state(1, dtypes.float32) @@ -742,7 +742,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): self.assertAllClose(self.evaluate(restore_layer_output), self.evaluate(cudnn_output)[-1, -1:, ...]) - def _CheckpointableSingleCellUnidirectionalTestTemplate( + def _TrackableSingleCellUnidirectionalTestTemplate( self, single_cell_fn, cudnn_cell_fn): # Single-layer cuDNN cells with object-based checkpointing should be # checkpoint compatible with either single CudnnCompatible cells or @@ -759,7 +759,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): value = np.random.normal(size=variable.shape) expected_values.append(value) self.evaluate(variable.assign(value)) - save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer) + save_checkpoint = trackable_utils.Checkpoint(cell=save_cell_layer) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") first_save_path = save_checkpoint.save(checkpoint_prefix) @@ -775,10 +775,10 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @test_util.run_in_graph_and_eager_modes - def testLSTMCheckpointableSingleLayer(self): + def testLSTMTrackableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION - self._CheckpointableSingleCellUnidirectionalTestTemplate( + 
self._TrackableSingleCellUnidirectionalTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), cudnn_cell_fn=functools.partial( @@ -788,19 +788,19 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @test_util.run_in_graph_and_eager_modes - def testGRUCheckpointableSingleLayer(self): + def testGRUTrackableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION with self.assertRaises(NotImplementedError): # TODO(allenl): Implement object-based saving for GRUs and other cells. - self._CheckpointableSingleCellUnidirectionalTestTemplate( + self._TrackableSingleCellUnidirectionalTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units), cudnn_cell_fn=functools.partial( cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units, direction=direction, name="awesome_gru")) - def _CheckpointableMultiLayerTestTemplate( + def _TrackableMultiLayerTestTemplate( self, single_cell_fn, cudnn_cell_fn, num_layers): def _MultiCellFn(): @@ -819,7 +819,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): value = np.random.normal(size=variable.shape) expected_values.append(value) self.evaluate(variable.assign(value)) - save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer) + save_checkpoint = trackable_utils.Checkpoint(cell=save_layer) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") first_save_path = save_checkpoint.save(checkpoint_prefix) @@ -837,7 +837,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): num_units = 2 num_layers = 3 direction = CUDNN_RNN_UNIDIRECTION - self._CheckpointableMultiLayerTestTemplate( + self._TrackableMultiLayerTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleLSTMCell, 
num_units=num_units), cudnn_cell_fn=functools.partial( @@ -1023,7 +1023,7 @@ class CudnnRNNTestCompatibleRNNCells(test_util.TensorFlowTestCase): outputs_v, output_state_v = sess.run( [outputs, output_state], feed_dict={cell_inputs: inference_input}) - self.assertAllClose(cudnn_outputs_v, outputs_v, atol=2e-5, rtol=2e-5) + self.assertAllClose(cudnn_outputs_v, outputs_v, atol=1e-4, rtol=2e-4) (cudnn_output_h_v,) = cudnn_output_states_v self.assertAllClose(cudnn_output_h_v, output_state_v, atol=2e-5, rtol=2e-5) diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 8e25637ed91a1559b321ea96efbfaa2910f67158..1cb477716dfc6a9cc793939059784f9d89bcdd8a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -374,7 +374,11 @@ class _CudnnRNN(base_layer.Layer): "This cell does not yet support object-based saving. File a feature " "request if this limitation bothers you.") - def call(self, inputs, initial_state=None, training=True): + def call(self, + inputs, + initial_state=None, + sequence_lengths=None, + training=True): """Runs the forward step for the RNN model. Args: @@ -382,6 +386,9 @@ class _CudnnRNN(base_layer.Layer): initial_state: a tuple of tensor(s) of shape `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. + sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the + batch_size. If not provided, the same sequence length will be assumed. training: whether this operation will be used in training or inference. Returns: output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. @@ -411,7 +418,7 @@ class _CudnnRNN(base_layer.Layer): # For model that doesn't take input_c, replace with a dummy tensor. 
c = array_ops.constant([], dtype=dtype) outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, - training) + sequence_lengths, training) if self._rnn_mode == CUDNN_LSTM: return outputs, (output_h, output_c) else: @@ -475,7 +482,7 @@ class _CudnnRNN(base_layer.Layer): dropout=self._dropout, direction=self._direction) - def _forward(self, inputs, h, c, opaque_params, training): + def _forward(self, inputs, h, c, opaque_params, sequence_lengths, training): output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access inputs, h, @@ -483,6 +490,7 @@ class _CudnnRNN(base_layer.Layer): opaque_params, training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -510,8 +518,8 @@ class _CudnnRNN(base_layer.Layer): direction=self.direction, scope=vs.get_variable_scope(), name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) - self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access - checkpointable=self, dtype=self._plain_dtype) + self._saveable._add_trackable_dependencies( # pylint: disable=protected-access + trackable=self, dtype=self._plain_dtype) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 1ce29b42d52ff67477161278ed11016c2e73041d..7d848e2ec2d99cd2a78ff3e813207c0cd5bb97cf 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking as checkpointable_lib +from tensorflow.python.training.tracking import 
tracking as trackable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -737,13 +737,13 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return state_ops.assign( self._variables, opaque_params, validate_shape=False) - def _checkpointable_save(self, save_buffer): + def _trackable_save(self, save_buffer): weights, biases = self.format_converter.opaque_to_tf_canonical( self._variables) for name, tensor in zip(self._param_names, weights + biases): save_buffer[name] = array_ops.identity(tensor) - def _checkpointable_restore(self, restore_buffer): + def _trackable_restore(self, restore_buffer): tensors = [ array_ops.identity(restore_buffer[name]) for name in self._param_names ] @@ -752,26 +752,26 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): restored_shapes=None # Unused ) - def _add_checkpointable_dependencies(self, checkpointable, dtype): - """Add canonical weight dependencies to `checkpointable`. + def _add_trackable_dependencies(self, trackable, dtype): + """Add canonical weight dependencies to `trackable`. When saving or restoring, converts to or from the opaque buffer format. Weights are saved and loaded in the configuration expected by cuDNN-compatible cells. Args: - checkpointable: An object inheriting from `CheckpointableBase` to add + trackable: An object inheriting from `Trackable` to add dependencies too (typically the cuDNN `Layer`). dtype: The dtype for the canonical parameter Tensors. 
""" split_dependencies = split_dependency.split_dependency( component_names=self._param_names, component_dtypes=(dtype,) * len(self._param_names), - fill_save_buffer_fn=self._checkpointable_save, - consume_restore_buffer_fn=self._checkpointable_restore) - self._checkpointable_track_params(checkpointable, split_dependencies) + fill_save_buffer_fn=self._trackable_save, + consume_restore_buffer_fn=self._trackable_restore) + self._trackable_track_params(trackable, split_dependencies) - def _checkpointable_track_params(self, checkpointable, params): + def _trackable_track_params(self, trackable, params): """Tracks parameters in a canonical configuration.""" return # NotImplementedError raised by the Layer. @@ -819,7 +819,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): tf_weights_names.append(prefix + "/kernel") tf_bias_names.append(prefix + "/bias") - def _checkpointable_track_params(self, checkpointable, params): + def _trackable_track_params(self, trackable, params): """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" biases = [] weights = [] @@ -833,12 +833,12 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): # wrapping. 
kernel, = weights # pylint: disable=unbalanced-tuple-unpacking bias, = biases # pylint: disable=unbalanced-tuple-unpacking - checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access - checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + trackable._track_trackable(kernel, name="kernel") # pylint: disable=protected-access + trackable._track_trackable(bias, name="bias") # pylint: disable=protected-access assert len(biases) == len(weights) for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.Checkpointable() - checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell = trackable_lib.AutoTrackable() + trackable._track_trackable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access cell.bias = bias cell.kernel = kernel @@ -955,6 +955,7 @@ def _cudnn_rnn(inputs, params, is_training, rnn_mode, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -972,6 +973,10 @@ def _cudnn_rnn(inputs, params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. 
@@ -1010,7 +1015,10 @@ def _cudnn_rnn(inputs, "seed2": seed2, "name": name } - if use_cudnn_v2 != "1": + if sequence_lengths is not None: + args["sequence_lengths"] = sequence_lengths + outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) + elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) else: outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) @@ -1022,6 +1030,7 @@ def cudnn_lstm(inputs, input_c, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1051,12 +1060,17 @@ def cudnn_lstm(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. name: name of the operation. Returns: outputs, output_h, output_c """ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, - input_mode, direction, dropout, seed, name) + sequence_lengths, input_mode, direction, dropout, seed, + name) def _cudnn_rnn_no_input_c(inputs, @@ -1064,6 +1078,7 @@ def _cudnn_rnn_no_input_c(inputs, params, is_training, rnn_mode, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1079,6 +1094,10 @@ def _cudnn_rnn_no_input_c(inputs, params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. 
Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. @@ -1098,8 +1117,8 @@ def _cudnn_rnn_no_input_c(inputs, """ input_c = array_ops.constant([], dtype=input_h.dtype) outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, - is_training, rnn_mode, input_mode, - direction, dropout, seed, name) + is_training, rnn_mode, sequence_lengths, + input_mode, direction, dropout, seed, name) return outputs, output_h @@ -1107,6 +1126,7 @@ def cudnn_gru(inputs, input_h, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1129,6 +1149,10 @@ def cudnn_gru(inputs, 'skip_input' is only allowed when input_size == num_units; 'auto_select' implies 'skip_input' when input_size == num_units; otherwise, it implies 'linear_input'. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1139,7 +1163,8 @@ def cudnn_gru(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU, - input_mode, direction, dropout, seed, name) + sequence_lengths, input_mode, direction, dropout, + seed, name) def cudnn_rnn_relu(inputs, @@ -1150,6 +1175,7 @@ def cudnn_rnn_relu(inputs, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., seed=0, + sequence_lengths=None, name=None): """Cudnn RNN Relu. 
@@ -1162,30 +1188,34 @@ def cudnn_rnn_relu(inputs, is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + for behavior. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. If not + provided, the same sequence length will be assumed. name: name of the operation. 
+ Returns: outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_RELU, input_mode, direction, dropout, - seed, name) + CUDNN_RNN_RELU, sequence_lengths, input_mode, + direction, dropout, seed, name) def cudnn_rnn_tanh(inputs, input_h, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1208,6 +1238,10 @@ def cudnn_rnn_tanh(inputs, 'skip_input' is only allowed when input_size == num_units; 'auto_select' implies 'skip_input' when input_size == num_units; otherwise, it implies 'linear_input'. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1218,8 +1252,8 @@ def cudnn_rnn_tanh(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_TANH, input_mode, direction, dropout, - seed, name) + CUDNN_RNN_TANH, sequence_lengths, input_mode, + direction, dropout, seed, name) def cudnn_rnn_opaque_params_to_canonical(rnn_mode, @@ -1497,7 +1531,13 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction) - def __call__(self, input_data, input_h, input_c, params, is_training=True): + def __call__(self, + input_data, + input_h, + input_c, + params, + is_training=True, + sequence_lengths=None): """Runs the forward step for the RNN model. Args: @@ -1509,6 +1549,10 @@ class _CudnnRNN(object): A Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. 
+ sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. Returns: output: the output sequence. output_h: the final state for h. @@ -1521,6 +1565,7 @@ class _CudnnRNN(object): params, is_training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -1615,7 +1660,13 @@ class CudnnLSTM(_CudnnRNN): dropout=dropout, seed=seed) - def __call__(self, input_data, input_h, input_c, params, is_training=True): + def __call__(self, + input_data, + input_h, + input_c, + params, + sequence_lengths=None, + is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: @@ -1626,6 +1677,10 @@ class CudnnLSTM(_CudnnRNN): input_c: the initial hidden state for c. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. + sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. @@ -1633,7 +1688,12 @@ class CudnnLSTM(_CudnnRNN): output_c: the final state for c. 
""" output, output_h, output_c = super(CudnnLSTM, self).__call__( - input_data, input_h, input_c, params, is_training=is_training) + input_data, + input_h, + input_c, + params, + sequence_lengths=sequence_lengths, + is_training=is_training) return (output, output_h, output_c) @@ -1687,7 +1747,12 @@ class _CudnnRNNNoInputC(_CudnnRNN): dropout=dropout, seed=seed) - def __call__(self, input_data, input_h, params, is_training=True): + def __call__(self, + input_data, + input_h, + params, + sequence_lengths=None, + is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: @@ -1696,6 +1761,10 @@ class _CudnnRNNNoInputC(_CudnnRNN): input_h: the initial hidden state for h. A Tensor of shape [num_layers, batch_size, num_units]. params: the parameter buffer created for this model. + sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. 
@@ -1707,6 +1776,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): params, is_training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 6c5f8c6b00975b3fba041271309a93cecd9f5057..4db711c1f3f2815e7b8cf275af315c062ce4c02e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -25,11 +25,13 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import script_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class AssertElementShapeTest(test_base.DatasetTestBase): def test_assert_element_shape(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index b9840b1ff1a3df5a05db0e64f436637220f49f80..220f9934b67d1d2a97f6c0fd4ba7779f011e1b09 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -27,12 +27,14 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.util import compat prefix_path = "tensorflow/core/lib" 
+@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class LMDBDatasetTest(test_base.DatasetTestBase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py index e7281d531870c75c638b5c48fa3fc6dc606a3623..78019fcc7d810da444f1407f3885d54e76a741c6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py @@ -25,10 +25,12 @@ from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 2527706709fae8e459aca3489324d4db3c784be6..9275a36582a8c82b936659041129b71e100f883e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -26,11 +26,13 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class 
SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index c0152156a1ba70297adb7054622b15ca04f859cd..c6bf5215c9406d03d2704e46903b3aa57e7e68d9 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -389,13 +389,11 @@ class LMDBDataset(dataset_ops.DatasetSource): Args: filenames: A `tf.string` tensor containing one or more filenames. """ - super(LMDBDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_lmdb_dataset( + variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset( self._filenames, **dataset_ops.flat_structure(self)) + super(LMDBDataset, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 5c6ee6bfdc7167d14b292f8f763adafca4e3a72c..6708e01d08135a132b797e317cd2a241c3428f40 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -30,7 +30,6 @@ class _SlideDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, window_size, window_shift, window_stride): """See `sliding_window_batch` for details.""" - super(_SlideDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._window_size = ops.convert_to_tensor( window_size, dtype=dtypes.int64, name="window_stride") @@ -43,14 +42,13 @@ class _SlideDataset(dataset_ops.UnaryDataset): input_dataset.output_types, input_dataset.output_shapes, input_dataset.output_classes) self._structure = input_structure._batch(None) # pylint: disable=protected-access - - def _as_variant_tensor(self): - return 
ged_ops.experimental_sliding_window_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_sliding_window_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access window_size=self._window_size, window_shift=self._window_shift, window_stride=self._window_stride, **dataset_ops.flat_structure(self)) + super(_SlideDataset, self).__init__(input_dataset, variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 8a8dc159ade6f2a4a9b5ec29055ea4848492b29f..dbcaf8185fb7a9d2bcf22376439c0ebd49accb1a 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -43,28 +43,19 @@ the workers. Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` with [tf.keras] (https://www.tensorflow.org/guide/keras). -Take a very simple model consisting of a single layer: +Let's define a simple input dataset for training this model. Note that currently we require using +[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) +with `DistributionStrategy`. ```python import tensorflow as tf from tensorflow import keras -inputs = tf.keras.layers.Input(shape=(1,)) -predictions = tf.keras.layers.Dense(1)(inputs) -model = tf.keras.models.Model(inputs=inputs, outputs=predictions) -``` - -Let's also define a simple input dataset for training this model. Note that currently we require using -[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) -with `DistributionStrategy`. 
- -```python features = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) labels = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) train_dataset = tf.data.Dataset.zip((features, labels)) ``` - To distribute this Keras model on multiple GPUs using `MirroredStrategy` we first instantiate a `MirroredStrategy` object. @@ -72,14 +63,17 @@ first instantiate a `MirroredStrategy` object. distribution = tf.contrib.distribute.MirroredStrategy() ``` -We then compile the Keras model and pass the `MirroredStrategy` object in the -`distribute` argument (apart from other usual arguments like `loss` and -`optimizer`). +Take a very simple model consisting of a single layer. We need to create and compile +the model under the distribution strategy scope. ```python -model.compile(loss='mean_squared_error', - optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2), - distribute=distribution) +with distribution.scope(): + inputs = tf.keras.layers.Input(shape=(1,)) + predictions = tf.keras.layers.Dense(1)(inputs) + model = tf.keras.models.Model(inputs=inputs, outputs=predictions) + + model.compile(loss='mean_squared_error', + optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2)) ``` To train the model we call Keras `fit` API using the input dataset that we diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 8ec73654e30e4967f318c558ba94301e84a206e4..59d76f5d1c817d7f2cc8ad285b9fb517fe994a81 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -30,12 +30,13 @@ from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.contrib.distribute.python.tpu_strategy import initialize_tpu_system 
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.distribute.cross_device_ops import * from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server -from tensorflow.python.training.distribute import * -from tensorflow.python.training.distribution_strategy_context import * +from tensorflow.python.distribute.distribute_lib import * +from tensorflow.python.distribute.distribution_strategy_context import * from tensorflow.python.util.all_util import remove_undocumented @@ -58,11 +59,14 @@ _allowed_symbols = [ 'StandardSingleLossStep', 'ReplicaContext', 'TPUStrategy', + 'initialize_tpu_system', 'get_cross_replica_context', 'get_distribution_strategy', 'get_loss_reduction', 'get_replica_context', + 'get_strategy', 'has_distribution_strategy', + 'has_strategy', 'in_cross_replica_context', 'require_replica_context', 'run_standard_tensorflow_server', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 4c9c35da5a36aa8149d15c8d1c25e4dfaa6a07c1..2ab94d00565376bfebd80ee61094831e09ed3e68 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -1,5 +1,10 @@ # Implementation of a prototype TF distributed computation library. +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") +load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + package( default_visibility = [ "//tensorflow:internal", @@ -10,11 +15,18 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -# TODO(priyag): Figure out testonly issues that are preventing us from -# including our tests in pip for now. 
+py_library( + name = "distribute_test_lib_pip", + visibility = ["//tensorflow:internal"], + deps = [ + ":combinations", + ":keras_correctness_test_lib", + ":keras_test_lib", + ":multi_worker_test_base", + ":single_loss_example", + ":strategy_test_lib", + ], +) cuda_py_test( name = "values_test", @@ -22,25 +34,36 @@ cuda_py_test( additional_deps = [ ":combinations", ":mirrored_strategy", - ":multi_worker_test_base", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:device_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", ], - tags = [ - "no_pip", +) + +cuda_py_test( + name = "input_lib_test", + srcs = ["input_lib_test.py"], + additional_deps = [ + ":combinations", + ":mirrored_strategy", + ":multi_worker_test_base", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:values", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", ], ) @@ -50,8 +73,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/distribute:values", ], ) @@ -60,18 +83,10 @@ py_library( srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":mirrored_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - 
"//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:parameter_server_strategy", "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", ], ) @@ -104,7 +119,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -113,15 +127,17 @@ py_library( srcs = ["one_device_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:reduce_util", - "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", - "@six_archive//:six", + "//tensorflow/python/distribute:one_device_strategy", + ], +) + +cuda_py_test( + name = "one_device_strategy_test", + srcs = ["one_device_strategy_test.py"], + additional_deps = [ + ":strategy_test_lib", + ":combinations", + "//tensorflow/python/eager:test", ], ) @@ -130,28 +146,16 @@ py_library( srcs = ["collective_all_reduce_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":mirrored_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:collective_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:cross_device_utils", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", + 
"//tensorflow/python/distribute:collective_all_reduce_strategy", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", ], ) py_library( name = "strategy_test_lib", - testonly = 1, srcs = ["strategy_test_lib.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -164,20 +168,18 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//third_party/py/numpy", ], ) py_library( name = "combinations", - testonly = 1, srcs = ["combinations.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ ":mirrored_strategy", ":one_device_strategy", + ":parameter_server_strategy", ":tpu_strategy", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", @@ -186,6 +188,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:context", + "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", ], ) @@ -193,30 +196,12 @@ py_library( py_test( name = "combinations_test", srcs = ["combinations_test.py"], - tags = [ - "no_pip", - ], deps = [ ":combinations", "//tensorflow/python/eager:test", ], ) -py_test( - name = "one_device_strategy_test", - srcs = ["one_device_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - ":one_device_strategy", - ":strategy_test_lib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python/eager:test", - ], -) - # TODO(priyag): Rename this test to mirrored_strategy_test cuda_py_test( name = "mirrored_strategy_multigpu_test", @@ -242,18 +227,13 @@ cuda_py_test( tags = [ "guitar", "multi_and_single_gpu", - "no_pip", ], ) py_library( name = "multi_worker_test_base", - testonly = 1, srcs = ["multi_worker_test_base.py"], srcs_version = "PY2AND3", 
- tags = [ - "no_pip", - ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -288,6 +268,8 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", ], @@ -320,14 +302,16 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) -py_library( - name = "minimize_loss_test_lib", - testonly = 1, +distribute_py_test( + name = "minimize_loss_test", srcs = ["minimize_loss_test.py"], + main = "minimize_loss_test.py", + tags = [ + "multi_and_single_gpu", + ], deps = [ ":combinations", ":mirrored_strategy", @@ -347,18 +331,6 @@ py_library( ], ) -cuda_py_test( - name = "minimize_loss_test", - srcs = ["minimize_loss_test.py"], - additional_deps = [ - ":minimize_loss_test_lib", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) - cuda_py_test( name = "moving_averages_test", srcs = ["moving_averages_test.py"], @@ -372,9 +344,6 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], - tags = [ - "no_pip", - ], ) cuda_py_test( @@ -392,7 +361,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -415,7 +383,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 - "no_pip", "tf_integration_test", ], ) @@ -426,10 +393,10 @@ cuda_py_test( additional_deps = [ ":keras_test_lib", ], + shard_count = 4, tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 - "no_pip", "tf_integration_test", ], ) @@ -459,7 +426,6 @@ cuda_py_test( shard_count = 48, tags = [ "multi_and_single_gpu", - "no_pip", # TODO(b/118768923): Re-enable {a,m,t}san test. 
"noasan", "nomsan", @@ -481,10 +447,13 @@ py_library( ], ) -py_library( - name = "step_fn_test_lib", - testonly = 1, +distribute_py_test( + name = "step_fn_test", srcs = ["step_fn_test.py"], + main = "step_fn_test.py", + tags = [ + "multi_and_single_gpu", + ], deps = [ ":combinations", ":single_loss_example", @@ -497,18 +466,6 @@ py_library( ], ) -cuda_py_test( - name = "step_fn_test", - srcs = ["step_fn_test.py"], - additional_deps = [ - ":step_fn_test_lib", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) - py_library( name = "monitor", srcs = ["monitor.py"], @@ -525,10 +482,10 @@ cuda_py_test( additional_deps = [ ":combinations", ":monitor", - ":one_device_strategy", ":single_loss_example", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", + "//tensorflow/python/distribute:one_device_strategy", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python:framework_ops", @@ -536,7 +493,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -553,15 +509,13 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - tags = [ - "no_pip", - ], ) cuda_py_test( name = "cross_device_ops_test", srcs = ["cross_device_ops_test.py"], additional_deps = [ + ":collective_all_reduce_strategy", ":combinations", ":multi_worker_test_base", ":mirrored_strategy", @@ -577,14 +531,16 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) py_library( name = "keras_test_lib", - testonly = 1, - srcs = ["keras_test.py"], + srcs = [ + "keras_backward_compat_test.py", + "keras_test.py", + "keras_utils_test.py", + ], deps = [ ":combinations", "//tensorflow/contrib/distribute/python:mirrored_strategy", @@ -599,46 +555,199 @@ py_library( ], ) -cuda_py_test( +distribute_py_test( name = "keras_test", srcs = ["keras_test.py"], - additional_deps = [ + full_precision = True, + main = "keras_test.py", + shard_count = 32, + tags = [ + "multi_and_single_gpu", + "no_oss", # 
TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ ":keras_test_lib", ], - shard_count = 16, +) + +distribute_py_test( + name = "keras_utils_test", + srcs = ["keras_utils_test.py"], + full_precision = True, + main = "keras_utils_test.py", + shard_count = 32, tags = [ "multi_and_single_gpu", "no_oss", # TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], + deps = [ + ":keras_test", + ":keras_test_lib", + ], +) + +# TODO(b/121200287): Remove this in 2.0 +distribute_py_test( + name = "keras_backward_compat_test", + srcs = ["keras_backward_compat_test.py"], + full_precision = True, + main = "keras_backward_compat_test.py", + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_test_lib", + ], ) py_library( - name = "metrics_v1_test_lib", - testonly = 1, - srcs = ["metrics_v1_test.py"], + name = "keras_correctness_test_lib", + srcs = [ + "keras_correctness_test_base.py", + "keras_dnn_correctness_test.py", + "keras_embedding_model_correctness_test.py", + "keras_image_model_correctness_test.py", + "keras_lstm_model_correctness_test.py", + "keras_stateful_lstm_model_correctness_test.py", + ], deps = [ ":combinations", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:tpu_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/keras", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -cuda_py_test( - name = "metrics_v1_test", - srcs = ["metrics_v1_test.py"], - additional_deps = [ - ":metrics_v1_test_lib", +distribute_py_test( + name = 
"keras_dnn_correctness_test", + size = "medium", + srcs = ["keras_dnn_correctness_test.py"], + full_precision = True, + main = "keras_dnn_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 19, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", ], +) + +distribute_py_test( + name = "keras_image_model_correctness_test", + size = "medium", + srcs = ["keras_image_model_correctness_test.py"], + full_precision = True, + main = "keras_image_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_embedding_model_correctness_test", + size = "medium", + srcs = ["keras_embedding_model_correctness_test.py"], + full_precision = True, + main = "keras_embedding_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_lstm_model_correctness_test", + size = "medium", + srcs = ["keras_lstm_model_correctness_test.py"], + full_precision = True, + main = "keras_lstm_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. 
+ "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_stateful_lstm_model_correctness_test", + size = "medium", + srcs = ["keras_stateful_lstm_model_correctness_test.py"], + full_precision = True, + main = "keras_stateful_lstm_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, tags = [ "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. "no_pip", + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + main = "metrics_v1_test.py", + tags = [ + "multi_and_single_gpu", + ], + deps = [ + ":combinations", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", ], ) @@ -656,7 +765,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -667,7 +775,6 @@ cuda_py_test( additional_deps = [ ":combinations", "//tensorflow/python:client_testlib", - "//tensorflow/python:checkpoint_utils_test", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -675,6 +782,25 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", + ], +) + +tf_xla_py_test( + name = "checkpointing_test", + srcs = ["checkpointing_test.py"], + disabled_backends = [ + # Only makes sense on TPUs + "cpu", + "gpu", + "cpu_ondemand", + ], + tags = [ + "no_oss", + ], + deps = [ + ":tpu_strategy", + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/python/eager:test", + "//tensorflow/python/training/tracking:util", ], ) diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index 
31bd0e996a247a2fc01405fb3b8172a40853d698..7ee50f03155636a487020d0a9178107a06775588 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -25,6 +25,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations @@ -33,7 +34,23 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import checkpoint_utils_test +from tensorflow.python.training import saver as saver_lib + + +def _create_checkpoints(sess, checkpoint_dir): + checkpoint_prefix = os.path.join(checkpoint_dir, "model") + checkpoint_state_name = "checkpoint" + v1 = variable_scope.get_variable("var1", [1, 10]) + v2 = variable_scope.get_variable("var2", [10, 10]) + sess.run(variables.global_variables_initializer()) + v1_value, v2_value = sess.run([v1, v2]) + saver = saver_lib.Saver() + saver.save( + sess, + checkpoint_prefix, + global_step=0, + latest_filename=checkpoint_state_name) + return v1_value, v2_value class CheckpointUtilsWithDistributionStrategyTest( @@ -51,8 +68,7 @@ class CheckpointUtilsWithDistributionStrategyTest( def testInitFromCheckpoint(self, distribution, in_replica_mode): checkpoint_dir = self.get_temp_dir() with self.cached_session() as session: - v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( - session, checkpoint_dir) + v1_value, v2_value = _create_checkpoints(session, checkpoint_dir) def init_and_verify(g): v1 = variable_scope.get_variable("new_var1", [1, 10]) @@ -71,7 +87,7 @@ class CheckpointUtilsWithDistributionStrategyTest( with ops.Graph().as_default() as g, distribution.scope(): if in_replica_mode: - 
distribution.call_for_each_replica(init_and_verify, args=[g]) + distribution.extended.call_for_each_replica(init_and_verify, args=[g]) else: init_and_verify(g) diff --git a/tensorflow/contrib/distribute/python/checkpointing_test.py b/tensorflow/contrib/distribute/python/checkpointing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eadf7233f2ae5ee50b71836ebfcc895163124ac2 --- /dev/null +++ b/tensorflow/contrib/distribute/python/checkpointing_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core +from tensorflow.python.platform import test +from tensorflow.python.training import adam as adam_v1 +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import training_util +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util as trackable_utils + + +class NonLayerTrackable(tracking.AutoTrackable): + + def __init__(self): + super(NonLayerTrackable, self).__init__() + self.a_variable = trackable_utils.add_variable( + self, name="a_variable", shape=[]) + + +class Subclassed(training.Model): + """A concrete Model for testing.""" + + def __init__(self): + super(Subclassed, self).__init__() + self._named_dense = core.Dense(1, use_bias=True) + self._second = core.Dense(1, use_bias=False) + # We can still track Trackables which aren't Layers. 
+ self._non_layer = NonLayerTrackable() + + def call(self, values): + ret = self._second(self._named_dense(values)) + return ret + + +class TrainingCheckpointTests(xla_test.XLATestCase): + + def testEagerTPUDistributionStrategy(self): + self.skipTest("b/121387144") + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + def _train_fn(optimizer, model): + input_value = constant_op.constant([[3.]]) + optimizer.minimize( + functools.partial(model, input_value), + global_step=root.optimizer_step) + + for training_continuation in range(3): + strategy = tpu_strategy.TPUStrategy() + with strategy.scope(): + model = Subclassed() + optimizer = adam_v1.AdamOptimizer(0.001) + root = trackable_utils.Checkpoint( + optimizer=optimizer, model=model, + optimizer_step=training_util.get_or_create_global_step()) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) + + for _ in range(num_training_steps): + strategy.extended.call_for_each_replica( + functools.partial(_train_fn, optimizer, model)) + root.save(file_prefix=checkpoint_prefix) + self.assertEqual((training_continuation + 1) * num_training_steps, + root.optimizer_step.numpy()) + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 5c50a20490482856becedf7b1379d2a0583d9a11..19741627980c34d8c281f7aed6f1464d4a03393e 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -18,27 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.core.protobuf import rewriter_config_pb2 
-from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import cross_device_utils -from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.distribute import values -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import collective_ops -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver # TODO(yuefengz): support in-graph replication. class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): """Distribution strategy that uses collective ops for all-reduce. + *** contrib version *** + It is similar to the MirroredStrategy but it uses collective ops for reduction. 
@@ -61,276 +52,19 @@ class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): CollectiveAllReduceExtended(self, num_gpus_per_worker)) -class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): +class CollectiveAllReduceExtended( + collective_all_reduce_strategy.CollectiveAllReduceExtended): """Implementation of CollectiveAllReduceStrategy.""" def __init__(self, container_strategy, num_gpus_per_worker): - distribute_lib.DistributionStrategyExtended.__init__( - self, container_strategy) - self._cross_device_ops = None - self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local_worker(num_gpus_per_worker) - - def _initialize_local_worker(self, num_gpus_per_worker): - """Initializes the object for local training.""" - self._is_chief = True - self._num_workers = 1 - - if num_gpus_per_worker: - local_devices = tuple( - "/device:GPU:%d" % i for i in range(num_gpus_per_worker) - ) - else: - local_devices = ("/device:CPU:0",) - self._worker_device = device_util.canonicalize("/device:CPU:0") - - self._collective_keys = cross_device_utils.CollectiveKeys() - self._initialize_local(local_devices) - self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys) - - self._cluster_spec = None - self._task_type = None - self._task_id = None - - logging.info("CollectiveAllReduceStrategy with local_devices = %r", - local_devices) - - def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, - task_type, task_id): - """Initializes the object for multi-worker training.""" - if task_type is None or task_id is None: - raise ValueError("When `cluster_spec` is given, you must also specify " - "`task_type` and `task_id`") - if task_type not in ("chief", "worker"): - raise ValueError( - "Unrecognized task_type: %r, valid task types are: \"chief\", " - "\"worker\"." 
% task_type) - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) - if not self._num_workers: - raise ValueError("No `worker` or `chief` tasks can be found in " - "`cluster_spec`.") - - self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, - task_id) - - self._worker_device = "/job:%s/task:%d" % (task_type, task_id) - if num_gpus_per_worker: - local_devices = tuple( - "%s/device:GPU:%d" % (self._worker_device, i) - for i in range(num_gpus_per_worker) - ) - else: - local_devices = (self._worker_device,) - - self._collective_keys = cross_device_utils.CollectiveKeys() - self._initialize_local(local_devices) - self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys) - - # Add a default device so that ops without specified devices will not end up - # on other workers. 
- self._default_device = "/job:%s/task:%d" % (task_type, task_id) - - self._cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._task_type = task_type - self._task_id = task_id - - logging.info( - "Multi-worker CollectiveAllReduceStrategy with " - "cluster_spec = %r, task_type = %r, task_id = %r, " - "num_workers = %r, local_devices = %r", cluster_spec.as_dict(), - task_type, task_id, self._num_workers, local_devices) - - def _create_variable(self, next_creator, *args, **kwargs): - colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) - group_size = len(devices) * self._num_workers - group_key = self._collective_keys.get_group_key(self._devices) - - def _real_mirrored_creator(devices, *args, **kwargs): - """Creates one MirroredVariable on the current worker.""" - index = {} - unique_var_name = ops.get_default_graph().unique_name( - kwargs["name"], mark_as_used=False).rstrip("/") - collective_instance_key = self._collective_keys.get_instance_key( - key_id=unique_var_name) - if "initial_value" not in kwargs: - raise ValueError("Initial value must be specified.") - initial_value = kwargs["initial_value"] - if callable(initial_value): - initial_value_fn = initial_value - else: - initial_value_fn = lambda: initial_value - - for i, d in enumerate(devices): - with ops.device(d): - if i > 0: - # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] - # We append a / to variable names created on replicas with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. - kwargs["name"] = "%s/replica_%d/" % (var0name, i) - - # The initial value fn makes sure variables all initialized to - # same values. The first device of the chief worker will send their - # variable values to other devices and other workers. 
- def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring - with ops.device(device): - initial_value = initial_value_fn() - assert not callable(initial_value) - initial_value = ops.convert_to_tensor(initial_value) - - if self._is_chief and index == 0: - bcast_send = collective_ops.broadcast_send( - initial_value, initial_value.shape, initial_value.dtype, - group_size, group_key, collective_instance_key) - with ops.control_dependencies([bcast_send]): - return array_ops.identity(initial_value) - else: - return collective_ops.broadcast_recv( - initial_value.shape, initial_value.dtype, group_size, - group_key, collective_instance_key) - - kwargs["initial_value"] = _overridden_initial_value_fn - - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) - - if i == 0: - actual_var_name = v.name.split(":")[0] - assert unique_var_name == actual_var_name, "%r vs %r" % ( - unique_var_name, actual_var_name) - assert not isinstance(v, values.DistributedVariable) - index[d] = v - return index - - # pylint: disable=protected-access - return mirrored_strategy._create_mirrored_variable( - devices, _real_mirrored_creator, *args, **kwargs) - - def _distribute_dataset(self, dataset_fn): - """Distributes the dataset to each local GPU.""" - # TODO(yuefengz): shard the dataset. 
- return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._devices, True) - - def _make_dataset_iterator(self, dataset): - worker_device_pairs = [(self._worker_device, self._devices)] - return values.DatasetIterator(dataset, worker_device_pairs, - self._num_replicas_in_sync) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - """Distributes the dataset to each local GPU.""" - if self._cluster_spec is None: - input_pipeline_id = 0 - else: - input_pipeline_id = multi_worker_util.id_in_cluster( - self._cluster_spec, self._task_type, self._task_id) - input_context = distribute_lib.InputContext( - num_input_pipelines=self._num_workers, - input_pipeline_id=input_pipeline_id, - num_replicas_in_sync=self._num_replicas_in_sync) - - return values.InputFunctionIterator( - input_fn, [(self._worker_device, self._devices)], [input_context]) - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - """Configures the object. - - Args: - session_config: a `tf.ConfigProto` - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type, such as "worker". - task_id: the current task id. - - Raises: - ValueError: if `task_type` is not in the `cluster_spec`. - """ - if not self._cluster_spec and cluster_spec: - # If a `cluster_spec` is already passed in, do nothing here. - # TODO(yuefengz): check `cluster_spec` is the same if this object has - # already been initialized with a `cluster_spec`. - self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, - task_type, task_id) - - if session_config: - session_config.CopyFrom(self._update_config_proto(session_config)) - - def _update_config_proto(self, config_proto): - updated_config = copy.deepcopy(config_proto) - # Enable the scoped allocator optimization for CollectiveOps. 
This - # optimization converts many small all-reduces into fewer larger - # all-reduces. - rewrite_options = updated_config.graph_options.rewrite_options - rewrite_options.scoped_allocator_optimization = ( - rewriter_config_pb2.RewriterConfig.ON) - # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op = - # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we - # clear and then append. - del rewrite_options.scoped_allocator_opts.enable_op[:] - rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") - - if not self._cluster_spec: - return updated_config - - assert self._task_type - assert self._task_id is not None - - # Collective group leader is needed for collective ops to coordinate - # workers. - if "chief" in self._cluster_spec.jobs: - updated_config.experimental.collective_group_leader = ( - "/job:chief/replica:0/task:0") - else: - if "worker" not in self._cluster_spec.jobs: - raise ValueError( - "You must have `chief` or `worker` jobs in the `cluster_spec`.") - updated_config.experimental.collective_group_leader = ( - "/job:worker/replica:0/task:0") - - # The device filters prevent communication between workers. - del updated_config.device_filters[:] - updated_config.device_filters.append( - "/job:%s/task:%d" % (self._task_type, self._task_id)) - - return updated_config - - @property - def experimental_between_graph(self): - return True - - @property - def experimental_should_init(self): - return True - - @property - def should_checkpoint(self): - return self._is_chief - - @property - def should_save_summary(self): - return self._is_chief - - @property - def _num_replicas_in_sync(self): - return len(self._devices) * self._num_workers - - # TODO(priyag): Delete this once all strategies use global batch size. - @property - def _global_batch_size(self): - return False + # Use TFConfigClusterResolver to parse TF_CONFIG. 
We don't want to change + # the constructor's interface to allow customized cluster resolver. Use + # SimpleClusterResolver to override num_accelerators. + tfconfig = TFConfigClusterResolver() + cluster_resolver = SimpleClusterResolver( + cluster_spec=tfconfig.cluster_spec(), + task_type=tfconfig.task_type, + task_id=tfconfig.task_id, + num_accelerators=num_gpus_per_worker) + super(CollectiveAllReduceExtended, self).__init__( + container_strategy, cluster_resolver=cluster_resolver) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 8a9e583f0afaac37a2057bae9b1ed79de43d68bc..ee7640dd1cea15e62ae9912ebedbd853778364a6 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -29,9 +29,13 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import collective_all_reduce_strategy as core_collective_all_reduce_strategy from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -49,6 +53,55 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.training import adam from tensorflow.python.training import training_util +from tensorflow.python.training.server_lib import 
ClusterSpec + + +class MockCollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): + """Mock the strategy to allow cluster resolver as an argument.""" + + def __init__(self, cluster_resolver): + super(MockCollectiveAllReduceStrategy, self).__init__( + core_collective_all_reduce_strategy.CollectiveAllReduceExtended( + self, cluster_resolver=cluster_resolver)) + + +def create_test_objects(cluster_spec=None, + task_type=None, + task_id=None, + num_gpus=None, + use_core_strategy=False): + sess_config = config_pb2.ConfigProto() + if num_gpus is None: + num_gpus = context.num_gpus() + if use_core_strategy: + if cluster_spec and task_type and task_id is not None: + cluster_resolver = SimpleClusterResolver( + cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), + task_type=task_type, + task_id=task_id, + num_accelerators=num_gpus) + target = 'grpc://' + cluster_spec[task_type][task_id] + else: + cluster_resolver = SimpleClusterResolver( + ClusterSpec({}), num_accelerators=num_gpus) + target = '' + + strategy = MockCollectiveAllReduceStrategy(cluster_resolver) + sess_config = strategy.update_config_proto(sess_config) + else: + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + if task_type and task_id is not None: + strategy.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[task_type][task_id] + else: + target = '' + + return strategy, target, sess_config class CollectiveAllReduceStrategyTestBase( @@ -64,16 +117,18 @@ class CollectiveAllReduceStrategyTestBase( CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 super(CollectiveAllReduceStrategyTestBase, self).setUp() - def _get_test_object(self, task_type, task_id, num_gpus=0): - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=num_gpus) - session_config = config_pb2.ConfigProto() - if 
task_type and task_id is not None: - distribution.configure( - session_config=session_config, - cluster_spec=self._cluster_spec, - task_type=task_type, - task_id=task_id) + def _get_test_object(self, + task_type, + task_id, + num_gpus=0, + use_core_strategy=False): + strategy, target, session_config = create_test_objects( + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id, + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + CollectiveAllReduceStrategyTestBase.collective_key_base, @@ -81,16 +136,16 @@ class CollectiveAllReduceStrategyTestBase( CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) - distribution.extended._collective_keys = collective_keys - distribution.extended._inferred_cross_device_ops._collective_keys = ( - collective_keys) - if task_type and task_id is not None: - return distribution, 'grpc://' + self._cluster_spec[task_type][ - task_id], session_config - else: - return distribution, '', session_config + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = (collective_keys) - def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + return strategy, target, session_config + + def _test_minimize_loss_graph(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, config = self._get_test_object(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ @@ -123,20 +178,20 @@ class CollectiveAllReduceStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=[one]) + g_v = d.extended.call_for_each_replica(grad_fn, args=[one]) # Update the variables using the gradients and the update() function. 
before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( - d.update(v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + d.extended.update(v, update, args=(g,), group=False)): + after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -158,7 +213,11 @@ class CollectiveAllReduceStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before - def _test_complex_model(self, task_type, task_id, num_gpus): + def _test_complex_model(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, config = self._get_test_object(task_type, task_id, num_gpus) @@ -192,6 +251,7 @@ class CollectiveAllReduceStrategyTestBase( image = random_ops.random_uniform([2, 28, 28]) label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32) logits = model(image, training=True) + # TODO(yuefengz): make loss a callable for eager mode. 
loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits) optimizer = adam.AdamOptimizer(learning_rate=1e-4) train_op = optimizer.minimize(loss, @@ -202,14 +262,18 @@ class CollectiveAllReduceStrategyTestBase( self.cached_session(config=config, target=master_target) as sess: with d.scope(): - train_op = d.call_for_each_replica(model_fn) + train_op = d.extended.call_for_each_replica(model_fn) train_op = d.group(d.unwrap(train_op)) sess.run(variables.global_variables_initializer()) sess.run(train_op) return True - def _test_variable_initialization(self, task_type, task_id, num_gpus): + def _test_variable_initialization(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) with ops.Graph().as_default(), \ @@ -225,7 +289,7 @@ class CollectiveAllReduceStrategyTestBase( 1.0, 10.0, dtype=dtypes.float32)) return array_ops.identity(x) - x = distribution.call_for_each_replica(model_fn) + x = distribution.extended.call_for_each_replica(model_fn) reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x) x = distribution.unwrap(x)[0] @@ -238,8 +302,14 @@ class CollectiveAllReduceStrategyTestBase( reduced_x_value))) return np.allclose(x_value, reduced_x_value, atol=1e-5) - def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, - expected_values): + def _test_input_fn_iterator(self, + task_type, + task_id, + num_gpus, + input_fn, + expected_values, + test_reinitialize=True, + use_core_strategy=False): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) devices = distribution.extended.worker_devices @@ -252,22 +322,24 @@ class CollectiveAllReduceStrategyTestBase( for expected_value in expected_values: next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) + computed_value = sess.run([values.select_replica(r, next_element) + for r in 
range(len(devices))]) self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - sess.run([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. - sess.run(iterator.initialize()) + if test_reinitialize: + sess.run(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) class DistributedCollectiveAllReduceStrategyTest( @@ -281,71 +353,116 @@ class DistributedCollectiveAllReduceStrategyTest( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0) - def test_num_replicas_in_sync(self): - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=2) - distribution.configure(cluster_spec=self._cluster_spec, task_type='worker', - task_id=0) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def test_num_replicas_in_sync(self, use_core_strategy): + distribution, _, _ = create_test_objects( + cluster_spec=self._cluster_spec, + task_type='worker', + task_id=0, + num_gpus=2, + use_core_strategy=use_core_strategy) num_workers = len(self._cluster_spec.get('chief', []) + self._cluster_spec.get('worker', [])) self.assertEqual(2 * num_workers, distribution.num_replicas_in_sync) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testMinimizeLossGraph(self, 
num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testMinimizeLossGraph(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testVariableInitialization(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testVariableInitialization(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_variable_initialization, self._cluster_spec, - num_gpus=num_gpus) + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testComplexModel(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testComplexModel(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._run_between_graph_clients( - self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + self._test_complex_model, + self._cluster_spec, + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. 
# TODO(yuefengz): Update how we use num_gpus and required_gpus @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testMakeInputFnIterator(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_dataset=[True, False], + use_core_strategy=[True, False])) + def DISABLED_testMakeInputFnIterator(self, num_gpus, use_dataset, + use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next # We use CPU as the device when num_gpus = 0 devices_per_worker = max(1, num_gpus) expected_values = [[i+j for j in range(devices_per_worker)] for i in range(0, 100, devices_per_worker)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=3*devices_per_worker, expected_num_input_pipelines=3, expected_input_pipeline_id=1) # because task_id = 1 - self._test_input_fn_iterator('worker', 1, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + 'worker', + 1, + num_gpus, + input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) - def testUpdateConfigProto(self): - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=2) - distribution.configure( - cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProto(self, use_core_strategy): + strategy, _, _ = self._get_test_object( + task_type='worker', + task_id=1, + num_gpus=2, + use_core_strategy=use_core_strategy) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) rewrite_options 
= config_proto.graph_options.rewrite_options rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed') - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify group leader self.assertEqual('/job:worker/replica:0/task:0', @@ -396,36 +513,136 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) -class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class LocalCollectiveAllReduceStrategy( + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): - def testMinimizeLossGraph(self, num_gpus=2): + @combinations.generate( + combinations.combine( + mode=['graph', 'eager'], + num_gpus=[2, 4], + required_gpus=2, + use_core_strategy=[True, False])) + def testMinimizeLoss(self, num_gpus, use_core_strategy): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - self._test_minimize_loss_graph(None, None, num_gpus) + if context.executing_eagerly(): + strategy, _, _ = self._get_test_object( + None, None, num_gpus, use_core_strategy=use_core_strategy) + self._test_minimize_loss_eager(strategy) + else: + self._test_minimize_loss_graph( + None, None, num_gpus, use_core_strategy=use_core_strategy) - def testComplexModel(self, num_gpus=2): - # Collective ops doesn't support strategy with one device. 
+ @combinations.generate( + combinations.combine( + mode=['graph'], + num_gpus=[2, 4], + required_gpus=2, + use_core_strategy=[True, False])) + def testComplexModel(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - self._test_complex_model(None, None, num_gpus) + self._test_complex_model( + None, None, num_gpus, use_core_strategy=use_core_strategy) - def testMakeInputFnIterator(self, num_gpus=2): - # Collective ops doesn't support strategy with one device. - if context.num_gpus() < num_gpus: - self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(10) - expected_values = [[i, i+1] for i in range(0, 10, 2)] + @combinations.generate( + combinations.combine( + mode=['graph', 'eager'], + required_gpus=2, + use_dataset=[True, False], + use_core_strategy=[True, False])) + def DISABLED_testMakeInputFnIterator(self, use_dataset, use_core_strategy): + num_gpus = 2 + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) + else: + def fn(): + dataset = dataset_ops.Dataset.range(5 * num_gpus) + it = dataset.make_one_shot_iterator() + return it.get_next + expected_values = [range(i, i + num_gpus) for i in range(0, 10, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=1, expected_input_pipeline_id=0) - self._test_input_fn_iterator(None, None, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + None, + None, + num_gpus, + input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSum(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with 
self.cached_session(config=config, target=target): + self._test_all_reduce_sum(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSumGradients(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradients(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSumGradientTape(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradient_tape(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceMean(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceMeanGradients(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradients(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def 
testAllReduceMeanGradientTape(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradient_tape(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testNumpyIterator(self, use_core_strategy): + num_gpus = 2 + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + strategy, _, _ = self._get_test_object( + None, None, num_gpus=num_gpus, use_core_strategy=use_core_strategy) + self._test_numpy_iterator(strategy) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 365ce5cdec79f1914f0c9ccdf59a7dc59e6f819e..7c0f8033fbc046580bc46f90ee9945ffa2a718f9 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,16 +46,22 @@ import unittest from absl.testing import parameterized import six -from tensorflow.contrib.cluster_resolver import TPUClusterResolver +from tensorflow.contrib import cluster_resolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib +from tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.contrib.tpu.python.tpu import device_assignment as device_assignment_lib from tensorflow.python.distribute import distribution_strategy_context from 
tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2 +from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2 +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras_v2 +from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_keras_v2 from tensorflow.python.training import adagrad from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent @@ -226,7 +232,7 @@ def combine(**kwargs): if not kwargs: return [OrderedDict()] - sort_by_key = lambda k: k[0][0] + sort_by_key = lambda k: k[0] kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) first = list(kwargs.items())[0] @@ -321,22 +327,49 @@ class NamedDistribution(object): return self._required_tpu +def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs): + def _create_tpu_strategy(): + resolver = cluster_resolver.TPUClusterResolver("") + topology = tpu_lib.initialize_tpu_system(resolver) + device_assignment = None + if use_single_core: + device_assignment = device_assignment_lib.DeviceAssignment( + topology, core_assignment=device_assignment_lib. 
+ SINGLE_CORE_ASSIGNMENT) + + strategy = tpu_lib.TPUStrategy(resolver, steps_per_run=steps_per_run, + device_assignment=device_assignment, + **kwargs) + return strategy + return _create_tpu_strategy + + # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access + distribution_strategy_context._get_default_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) +one_device_strategy_gpu = NamedDistribution( + "OneDeviceGPU", lambda: one_device_lib.OneDeviceStrategy("/gpu:0"), + required_gpus=1) tpu_strategy = NamedDistribution( - "TPU", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=2), + "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True) tpu_strategy_one_step = NamedDistribution( - "TPUOneStep", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=1), + "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), + required_tpu=True) +tpu_strategy_one_core = NamedDistribution( + "TPUOneCore", _get_tpu_strategy_creator( + steps_per_run=2, use_single_core=True), required_tpu=True) +tpu_strategy_one_step_one_core = NamedDistribution( + "TPUOneStepOneCore", _get_tpu_strategy_creator( + steps_per_run=1, use_single_core=True), + required_tpu=True) + mirrored_strategy_with_one_cpu = NamedDistribution( "Mirrored1CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) @@ -367,6 +400,11 @@ core_mirrored_strategy_with_two_gpus = NamedDistribution( "CoreMirrored2GPUs", lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) +parameter_server_strategy_with_two_gpus = NamedDistribution( + "ParameterServer2GPUs", + lambda: parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2), + required_gpus=2) 
gradient_descent_optimizer_v1_fn = NamedObject( @@ -386,10 +424,20 @@ gradient_descent_optimizer_v2_fn = NamedObject( adagrad_optimizer_v2_fn = NamedObject( "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1.0)) optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn] +gradient_descent_optimizer_keras_v2_fn = NamedObject( + "GradientDescentKerasV2", + lambda: gradient_descent_keras_v2.SGD(0.2)) +adagrad_optimizer_keras_v2_fn = NamedObject( + "AdagradKerasV2", lambda: adagrad_keras_v2.Adagrad(0.001)) +adam_optimizer_keras_v2_fn = NamedObject( + "AdamKerasV2", lambda: adam_keras_v2.Adam(0.001, epsilon=1.0)) +rmsprop_optimizer_keras_v2_fn = NamedObject( + "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001)) + graph_and_eager_modes = ["graph", "eager"] diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py index 86aa48cea889c6c2ce169b18bcabb6d08890fbed..9f3deadbec98c4f66061ca29b4d29a74b8de40b1 100644 --- a/tensorflow/contrib/distribute/python/combinations_test.py +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -42,6 +42,14 @@ class TestingCombinationsTest(test.TestCase): "b": 3 }], combinations.combine(a=[1, 2], b=[2, 3])) + def test_arguments_sorted(self): + self.assertEqual([ + OrderedDict([("aa", 1), ("ab", 2)]), + OrderedDict([("aa", 1), ("ab", 3)]), + OrderedDict([("aa", 2), ("ab", 2)]), + OrderedDict([("aa", 2), ("ab", 3)]) + ], combinations.combine(ab=[2, 3], aa=[1, 2])) + def test_combine_single_parameter(self): self.assertEqual([{ "a": 1, diff --git a/tensorflow/contrib/distribute/python/cross_device_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py index d6e9521c1c1115ffdbdcf375ad4017bacb962832..2b8e0197961ae37b67dc8958054a03e164242dcd 100644 --- 
a/tensorflow/contrib/distribute/python/cross_device_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -23,6 +23,7 @@ import itertools from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base @@ -40,8 +41,16 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +def _get_devices(devices): + if isinstance(devices, (tuple, list)): + return tuple(device_util.resolve(d) for d in devices) + elif isinstance(devices, value_lib.DistributedValues): + return devices.devices + return (device_util.resolve(devices),) + + def _make_per_replica(values, devices, regroup=False): - devices = cross_device_ops_lib.get_devices_from(devices) + devices = _get_devices(devices) assert len(values) == len(devices) # We simulate the result of regroup called on PerReplica which strips the @@ -51,12 +60,12 @@ def _make_per_replica(values, devices, regroup=False): placed_v = array_ops.identity(values[0]) return placed_v - index = {} + index = [] for d, v in zip(devices, values): with ops.device(d): placed_v = array_ops.identity(v) - index[d] = placed_v - return value_lib.PerReplica(index) + index.append(placed_v) + return value_lib.PerReplica(value_lib.ReplicaDeviceMap(devices), index) # pylint: disable=g-doc-args,g-doc-return-or-yield @@ -66,9 +75,9 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. 
""" - devices = cross_device_ops_lib.get_devices_from(devices) - return value_lib.Mirrored( - {d: v for d, v in zip(devices, [value] * len(devices))}) + devices = _get_devices(devices) + return value_lib.Mirrored(value_lib.ReplicaDeviceMap(devices), + [value] * len(devices)) def _make_indexed_slices(values, indices, dense_shape, device): @@ -81,9 +90,9 @@ def _make_indexed_slices(values, indices, dense_shape, device): def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): - return value_lib.Mirrored({ - d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices - }) + values = [_make_indexed_slices(values, indices, dense_shape, d) + for d in devices] + return value_lib.Mirrored(value_lib.ReplicaDeviceMap(devices), values) _cpu_device = "/device:CPU:0" @@ -107,16 +116,16 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): else: self.assertEqual(type(left), type(right)) self.assertEqual(set(left.devices), set(right.devices)) - if isinstance(list(left._index.values())[0], ops.IndexedSlices): - for (d, v) in left._index.items(): - self._assert_indexed_slices_equal(v, right._index[d]) + if isinstance(left.values[0], ops.IndexedSlices): + for d in left.devices: + self._assert_indexed_slices_equal(left.get(d), right.get(d)) elif context.executing_eagerly(): - self.assertEqual([v.numpy() for v in left._index.values()], - list(right._index.values())) + self.assertEqual([v.numpy() for v in left.values], + list(right.values)) else: with self.cached_session() as sess: self.assertEqual( - sess.run(list(left._index.values())), list(right._index.values())) + sess.run(list(left.values)), list(right.values)) def _testReductionAndBroadcast(self, cross_device_ops, distribution): devices = distribution.extended.worker_devices @@ -196,15 +205,15 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): reduction_to_one_combinations = combinations.combine( cross_device_ops=[ combinations.NamedObject( - 
"DefaultReductionToOneDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + "DefaultReductionToOneDevice", + cross_device_ops_lib.ReductionToOneDevice()), combinations.NamedObject( "ReductionToCPUDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDevice( reduce_to_device=_cpu_device)), combinations.NamedObject( "AccumulateNCrossDeviceOp", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDevice( accumulation_fn=math_ops.accumulate_n)), ], distribution=[ @@ -220,20 +229,23 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): combinations.NamedObject( "AllReduce", cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), - combinations.NamedObject( - "HierarchicalCopy", - cross_device_ops_lib.AllReduceCrossDeviceOps( - "hierarchical_copy", 8, 0, 0)), combinations.NamedObject( "AllReduceNoGradientRepacking", cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), + combinations.NamedObject("NcclAllReduce", + cross_device_ops_lib.NcclAllReduce()), + combinations.NamedObject( + "HierarchicalCopy", + cross_device_ops_lib.HierarchicalCopyAllReduce(8)), combinations.NamedObject( "HierarchicalCopyAggregateSmallTensors", cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 0, 100, 10)) ], - distribution=[combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus + ], mode=["graph", "eager"]) @combinations.generate(reduction_to_one_combinations + allreduce_combinations) @@ -280,7 +292,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): devices = ["/cpu:0", "/gpu:0"] t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) - per_replica = 
value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica( + value_lib.ReplicaDeviceMap(devices), (t0, t1)) result = cross_device_ops_lib._simple_reduce( per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) @@ -297,8 +310,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): combinations.combine( cross_device_ops_instance=[ combinations.NamedObject( - "ReductionToOneDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + "ReductionToOneDevice", + cross_device_ops_lib.ReductionToOneDevice()), combinations.NamedObject( "AllReduceCrossDeviceOps", cross_device_ops_lib.AllReduceCrossDeviceOps()) @@ -314,7 +327,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) t1 = _make_indexed_slices( [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) - per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica( + value_lib.ReplicaDeviceMap(devices), (t0, t1)) if batch_reduce: result = cross_device_ops_instance.batch_reduce( @@ -416,6 +430,9 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, self._testReductionAndBroadcast(cross_device_ops, distribution) +NUM_WORKERS = 3 + + class MultiWorkerCollectiveAllReduceTest( multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): @@ -423,9 +440,9 @@ class MultiWorkerCollectiveAllReduceTest( @classmethod def setUpClass(cls): - """Create a local cluster with 2 workers.""" + """Create a local cluster with 3 workers.""" cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=0) + num_workers=NUM_WORKERS, num_ps=0) def setUp(self): super(MultiWorkerCollectiveAllReduceTest, self).setUp() @@ -433,7 +450,12 @@ class MultiWorkerCollectiveAllReduceTest( # collective key base for different tests. 
MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 - def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + def _get_test_objects(self, + task_type, + task_id, + num_gpus=0, + use_strategy_object=False, + local_mode=False): collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + MultiWorkerCollectiveAllReduceTest.collective_key_base, @@ -442,16 +464,24 @@ class MultiWorkerCollectiveAllReduceTest( instance_key_with_id_start=num_gpus * 10000 + MultiWorkerCollectiveAllReduceTest.collective_key_base) if local_mode: - collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( - 1, num_gpus, collective_keys=collective_keys) if num_gpus: devices = ["/device:GPU:%d" % i for i in range(num_gpus)] else: devices = ["/device:CPU:0"] - return collective_all_reduce_ops, devices, "" + + if use_strategy_object: + # Still using contrib CollectiveAllReduceStrategy because we can specify + # num_gpus in its constructor. + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = collective_keys + return strategy, devices, "" + else: + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + return collective_all_reduce_ops, devices, "" else: - collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( - 3, num_gpus, collective_keys=collective_keys) if num_gpus: devices = [ "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) @@ -459,8 +489,23 @@ class MultiWorkerCollectiveAllReduceTest( ] else: devices = ["/job:%s/task:%d" % (task_type, task_id)] - return (collective_all_reduce_ops, devices, - "grpc://" + self._cluster_spec[task_type][task_id]) + + if use_strategy_object: + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + 
num_gpus_per_worker=num_gpus) + strategy.configure( + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id) + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = collective_keys + return (strategy, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) + else: + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( + NUM_WORKERS, num_gpus, collective_keys=collective_keys) + return (collective_all_reduce_ops, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) def _assert_values_equal(self, left, right, sess): if isinstance(left, list): @@ -474,15 +519,24 @@ class MultiWorkerCollectiveAllReduceTest( run_options.experimental.collective_graph_key = 6 left_values = np.array( - sess.run(list(left._index.values()), options=run_options)).flatten() - right_values = np.array(list(right._index.values())).flatten() + sess.run(list(left.values), options=run_options)).flatten() + right_values = np.array(list(right.values)).flatten() self.assertEqual(len(left_values), len(right_values)) for l, r in zip(left_values, right_values): self.assertEqual(l, r) - def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + def _test_reduction(self, + task_type, + task_id, + num_gpus, + use_strategy_object=False, + local_mode=False): collective_all_reduce, devices, master_target = self._get_test_objects( - task_type, task_id, num_gpus, local_mode=local_mode) + task_type, + task_id, + num_gpus, + use_strategy_object=use_strategy_object, + local_mode=local_mode) if local_mode: num_workers = 1 worker_device = None @@ -490,13 +544,34 @@ class MultiWorkerCollectiveAllReduceTest( num_workers = len(self._cluster_spec.get("chief", [])) + len( self._cluster_spec.get("worker", [])) worker_device = "/job:%s/task:%d" % (task_type, task_id) + + def _reduce(test_object, reduce_op, per_replica, destinations): + if use_strategy_object: + with test_object.scope(): + # Mimic the 
behavior that distribution strategy usually strips the + # wrapper if there is only one value. + if len(per_replica.values) == 1: + per_replica = per_replica.values[0] + return test_object.extended.reduce_to(reduce_op, per_replica, + destinations) + else: + return test_object.reduce(reduce_op, per_replica, destinations) + + def _batch_reduce(test_object, reduce_op, value_destination_pairs): + if use_strategy_object: + with test_object.scope(): + return test_object.extended.batch_reduce_to(reduce_op, + value_destination_pairs) + else: + return test_object.batch_reduce(reduce_op, value_destination_pairs) + with ops.Graph().as_default(), \ ops.device(worker_device), \ self.cached_session(target=master_target) as sess: # Collective ops doesn't support scalar tensors, so we have to construct # 1-d tensors. values = [constant_op.constant([float(d)]) for d in range(len(devices))] - per_replica = _make_per_replica(values, devices, regroup=True) + per_replica = _make_per_replica(values, devices) mean = np.array([(len(devices) - 1.) 
/ 2.]) values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] @@ -514,26 +589,30 @@ class MultiWorkerCollectiveAllReduceTest( # test reduce() for destinations in all_destinations: self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.MEAN, per_replica, - destinations=destinations), - _fake_mirrored(mean, destinations), sess) + destinations=destinations), _fake_mirrored(mean, destinations), + sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.MEAN, per_replica_2, - destinations=destinations), - _fake_mirrored(mean_2, destinations), sess) + destinations=destinations), _fake_mirrored( + mean_2, destinations), sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), @@ -543,17 +622,13 @@ class MultiWorkerCollectiveAllReduceTest( # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, - [(per_replica, d1), - (per_replica_2, d2)]), - [ - _fake_mirrored(mean, d1), - _fake_mirrored(mean_2, d2) - ], sess) + _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.MEAN, + [(per_replica, d1), (per_replica_2, d2)]), + [_fake_mirrored(mean, d1), + _fake_mirrored(mean_2, d2)], sess) self._assert_values_equal( - collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, - [(per_replica, d1), - (per_replica_2, d2)]), + _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.SUM, + [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices) * 
num_workers, d1), _fake_mirrored(mean_2 * len(devices) * num_workers, d2) @@ -562,18 +637,36 @@ class MultiWorkerCollectiveAllReduceTest( return True @combinations.generate( - combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1)) - def testReductionDistributed(self, num_gpus): + combinations.combine( + mode=["graph"], + num_gpus=[0, 1, 2], + required_gpus=1, + use_strategy_object=[True, False])) + def testReductionDistributed(self, num_gpus, use_strategy_object): if context.num_gpus() < num_gpus: return - self._run_between_graph_clients(self._test_reduction, self._cluster_spec, - num_gpus) + self._run_between_graph_clients( + self._test_reduction, + self._cluster_spec, + num_gpus, + use_strategy_object=use_strategy_object) # Collective ops doesn't support strategy with one device. - def testReductionLocal(self, num_gpus=2): + @combinations.generate( + combinations.combine( + mode=["graph"], + num_gpus=[2], + required_gpus=2, + use_strategy_object=[True, False])) + def testReductionLocal(self, num_gpus, use_strategy_object): if context.num_gpus() < num_gpus: return - self._test_reduction(None, None, num_gpus, local_mode=True) + self._test_reduction( + None, + None, + num_gpus, + use_strategy_object=use_strategy_object, + local_mode=True) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/cross_device_utils_test.py b/tensorflow/contrib/distribute/python/cross_device_utils_test.py index 2303a31677afbd12a0b8e7eea3ecf7c7736c46ad..275aac2eeca575e927878d1ece63ce37ed38e8a0 100644 --- a/tensorflow/contrib/distribute/python/cross_device_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_utils_test.py @@ -103,7 +103,8 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) + device_map = 
value_lib.ReplicaDeviceMap(("/gpu:0", "/cpu:0")) + per_replica = value_lib.PerReplica(device_map, (t0, t1)) self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica)) @combinations.generate(combinations.combine( diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index e17085628ba6d1dfc79839fd824801723f07a518..1ff1e7c1d255492e0535175dae7594d2ceb4010b 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -22,7 +22,6 @@ import shutil import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.optimizer_v2 import adagrad @@ -117,7 +116,7 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, scores = estimator.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) predictions = np.array([ x[prediction_keys.PredictionKeys.PREDICTIONS] diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index b369a7fefe6f35cf5a9b64451419cf4f72a99471..3f55a8a1c8b88d1b8e4031547fa3fbe519983630 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) + threads_to_join = [] for task_type, ts in threads.items(): if task_type == PS: continue for t in ts: - t.join() + threads_to_join.append(t) + self.join_independent_workers(threads_to_join) estimator = self._get_estimator(train_distribute, 
eval_distribute) self._inspect_train_and_eval_events(estimator) @@ -413,8 +415,7 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) - threads[WORKER][0].join() - threads[EVALUATOR][0].join() + self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]]) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD index 84b106545e1326fddd3ed299462534af982dc102..5f89df5824a8d03198987a6fa3d21e2330deedf0 100644 --- a/tensorflow/contrib/distribute/python/examples/BUILD +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -31,6 +31,12 @@ py_binary( py_binary( name = "keras_mnist", + srcs = ["keras_mnist.py"], + deps = [":keras_mnist_lib"], +) + +py_library( + name = "keras_mnist_lib", srcs = [ "keras_mnist.py", ], @@ -39,3 +45,14 @@ py_binary( "//third_party/py/numpy", ], ) + +py_binary( + name = "mnist_eager_multigpu", + srcs = [ + "mnist_eager_multigpu.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index 60fda996642464135fe1fb8c314bcf7f04d19362..1ce91ecaf22a80a53124c8f00fac05c6b4711ed9 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -109,22 +109,21 @@ def main(_): tf.enable_eager_execution() train_ds, eval_ds, input_shape = get_input_datasets() - model = get_model(input_shape) # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or # the `devices` argument then all the GPUs available on the machine are used. 
# TODO(priyag): Use `tf.distribute.MirroredStrategy` once available. strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0']) - optimizer = rmsprop.RMSProp(learning_rate=0.001) - - # Compile the model by passing the distribution strategy object to the - # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed - # based on the strategy instantiated. - model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=optimizer, - metrics=['accuracy'], - distribute=strategy) + # Create and compile the model under Distribution strategy scope. + # `fit`, `evaluate` and `predict` will be distributed based on the strategy + # model was compiled with. + with strategy.scope(): + model = get_model(input_shape) + optimizer = rmsprop.RMSProp(learning_rate=0.001) + model.compile(loss=tf.keras.losses.categorical_crossentropy, + optimizer=optimizer, + metrics=['accuracy']) # Train the model with the train dataset. model.fit(x=train_ds, epochs=20, steps_per_epoch=468) diff --git a/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py new file mode 100644 index 0000000000000000000000000000000000000000..c045a5586b9dad371d8c505f9cac4b792dd157fd --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py @@ -0,0 +1,169 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Run MNIST on multiple GPUs on using MirroredStrategy with eager execution. + +By default, runs on all available GPUs, or CPU if no GPUs are available. + +NOTE: Currently, this takes more time than when running MNIST in eager without +MirroredStrategy because of a number overheads. Therefore, this is just a +proof of concept right now and cannot be used to actually scale up training. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags +import numpy as np +import tensorflow.compat.v2 as tf + +flags.DEFINE_integer("num_gpus", None, "How many GPUs should we run on?" + "Defaults to all available GPUs, otherwise CPU.") +flags.DEFINE_integer("batch_size", 64, + "What should be the size of each batch?") +flags.DEFINE_integer("num_epochs", 10, "How many epochs to run?") +flags.DEFINE_float("learning_rate", 0.01, "Learning Rate") +flags.DEFINE_float("momentum", 0.5, "SGD momentum") +flags.DEFINE_boolean("use_function", False, + "Should we wrap the step in a tf.function.") + +FLAGS = flags.FLAGS +NUM_TRAIN_IMAGES = 60000 + + +def create_model(): + max_pool = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding="same") + # The model consists of a sequential chain of layers, so tf.keras.Sequential + # (a subclass of tf.keras.Model) makes for a compact description. 
+ return tf.keras.Sequential([ + tf.keras.layers.Reshape( + target_shape=[28, 28, 1], + input_shape=(28, 28,)), + tf.keras.layers.Conv2D(2, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Conv2D(4, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(32, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.4), + tf.keras.layers.Dense(10)]) + + +def compute_loss(logits, labels): + loss = tf.reduce_sum( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels)) + # Scale loss by global batch size. + return loss * (1. / FLAGS.batch_size) + + +def mnist_datasets(): + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + # Numpy defaults to dtype=float64; TF defaults to float32. Stick with float32. + x_train, x_test = x_train / np.float32(255), x_test / np.float32(255) + y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64) + # TODO(priyag): `strategy.make_numpy_iterator` can be used directly instead of + # converting to datasets. 
+ train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + return train_dataset, test_dataset + + +def main(unused_argv): + """Run a CNN model on MNIST data to demonstrate DistributedStrategies.""" + + tf.enable_v2_behavior() + + num_gpus = FLAGS.num_gpus + if num_gpus is None: + devices = None + elif num_gpus == 0: + devices = ["/device:CPU:0"] + else: + devices = ["/device:GPU:{}".format(i) for i in range(num_gpus)] + strategy = tf.distribute.MirroredStrategy(devices) + + with strategy.scope(): + train_ds, test_ds = mnist_datasets() + train_ds = train_ds.shuffle(NUM_TRAIN_IMAGES).batch(FLAGS.batch_size) + test_ds = test_ds.batch(FLAGS.batch_size) + + model = create_model() + optimizer = tf.keras.optimizers.SGD(FLAGS.learning_rate, FLAGS.momentum) + training_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32) + training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "training_accuracy", dtype=tf.float32) + test_loss = tf.keras.metrics.Mean("test_loss", dtype=tf.float32) + test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "test_accuracy", dtype=tf.float32) + + def train_step(inputs): + images, labels = inputs + with tf.GradientTape() as tape: + logits = model(images, training=True) + loss = compute_loss(logits, labels) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) + training_loss.update_state(loss) + training_accuracy.update_state(labels, logits) + + def test_step(inputs): + images, labels = inputs + logits = model(images, training=False) + loss = compute_loss(logits, labels) + test_loss.update_state(loss) + test_accuracy.update_state(labels, logits) + + train_iterator = strategy.make_dataset_iterator(train_ds) + test_iterator = strategy.make_dataset_iterator(test_ds) + + for epoch in range(0, FLAGS.num_epochs): + # TODO(b/123315763): Create the tf.function outside this loop once we are + # 
able to initialize iterator in eager mode. + dist_train = lambda it: strategy.experimental_run(train_step, it) + dist_test = lambda it: strategy.experimental_run(test_step, it) + if FLAGS.use_function: + dist_train = tf.function(dist_train) + dist_test = tf.function(dist_test) + + # Train + print("Starting epoch {}".format(epoch)) + train_iterator.initialize() + while True: + try: + dist_train(train_iterator) + except tf.errors.OutOfRangeError: + break + print("Training loss: {:0.4f}, accuracy: {:0.2f}%".format( + training_loss.result(), training_accuracy.result() * 100)) + training_loss.reset_states() + training_accuracy.reset_states() + + # Test + test_iterator.initialize() + while True: + try: + dist_test(test_iterator) + except tf.errors.OutOfRangeError: + break + print("Test loss: {:0.4f}, accuracy: {:0.2f}%".format( + test_loss.result(), test_accuracy.result() * 100)) + test_loss.reset_states() + test_accuracy.reset_states() + + +if __name__ == "__main__": + app.run(main) diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py new file mode 100644 index 0000000000000000000000000000000000000000..204f52b034f2366a42fbdab41c467feddb5969a0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the input_lib library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import errors +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util import nest + + +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = input_lib.InputFunctionIterator( + input_fn, input_workers, input_contexts) + else: + iterator = input_lib.DatasetIterator( + dataset_fn(), input_workers, split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = 
iterator.get_next() + evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) + + # After re-initializing the iterator, should be able to iterate again. + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + 
expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def 
testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c49b5522f9135efd9ae3005e92099caf54a76a3a --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py @@ -0,0 +1,1083 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.keras models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import random_seed +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.keras.utils.mode_keys import ModeKeys +from tensorflow.python.ops.parsing_ops import gen_parsing_ops +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import rmsprop + +_RANDOM_SEED = 1337 +_TRAIN_SIZE = 200 +_INPUT_SIZE = (10,) +_NUM_CLASS = 2 + + +# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as +# part of the tf.keras unit tests suite. 
+def simple_sequential_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) + model.add(keras.layers.Dropout(0.1)) + model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax')) + return model + + +def simple_functional_model(): + a = keras.layers.Input(shape=_INPUT_SIZE) + b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dropout(0.1)(b) + b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) + model = keras.models.Model(inputs=[a], outputs=[b]) + return model + + +def multi_inputs_multi_outputs_model(): + input_a = keras.layers.Input(shape=(16,), name='input_a') + input_b = keras.layers.Input(shape=(16,), name='input_b') + input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m') + dense = keras.layers.Dense(8, name='dense_1') + + interm_a = dense(input_a) + # Read m + interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m) + interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a]) + interm_b = dense(input_b) + merged = keras.layers.concatenate([interm_s, interm_b], name='merge') + output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) + output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged) + model = keras.models.Model( + inputs=[input_a, input_b, input_m], outputs=[output_c, output_d]) + model.compile( + loss='categorical_crossentropy', + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + metrics={ + 'dense_2': 'categorical_accuracy', + 'dense_3': 'categorical_accuracy' + }) + return model + + +def get_ds_train_input_fn(): + np.random.seed(_RANDOM_SEED) + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_train = keras.utils.to_categorical(y_train) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.batch(32) 
+ return dataset + + +def get_ds_test_input_fn(): + np.random.seed(_RANDOM_SEED) + _, (x_test, y_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_test = keras.utils.to_categorical(y_test) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) + dataset = dataset.batch(32) + return dataset + + +def get_multi_inputs_multi_outputs_data(): + (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(16,), + num_classes=3, + random_seed=_RANDOM_SEED) + (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(16,), + num_classes=2, + random_seed=_RANDOM_SEED) + (m_train, _), (m_test, _) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(8,), + num_classes=2, + random_seed=_RANDOM_SEED) + + c_train = keras.utils.to_categorical(c_train) + c_test = keras.utils.to_categorical(c_test) + d_train = keras.utils.to_categorical(d_train) + d_test = keras.utils.to_categorical(d_test) + + train_data = { + 'input_a': a_train, + 'input_b': b_train, + 'input_m': m_train, + 'output_c': c_train, + 'output_d': d_train + } + test_data = { + 'input_a': a_test, + 'input_b': b_test, + 'input_m': m_test, + 'output_c': c_test, + 'output_d': d_test + } + + return (train_data, test_data) + + +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. 
+ if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def get_model(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + return model + + +def get_dataset(distribution): + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 10, distribution) + return dataset + + +def get_predict_dataset(distribution): + inputs = np.zeros((10, 3), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 10, distribution) + return dataset + + +def multi_input_output_model(): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(5,), name='input_b') + # TODO(anjalisridhar): Change the output dimension of the second Dense layer + # once the iterator output validation issue has been fixed. + dense_1 = keras.layers.Dense(7, name='dense_1') + dense_2 = keras.layers.Dense(7, name='dense_2') + c = dense_1(a) + d = dense_2(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + model = keras.models.Model([a, b], [d, e]) + return model + + +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 + global_batch_size = 64 + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. 
+ use_per_core_batch_size = ( + with_distribution and + not distributed_training_utils.global_batch_size_supported( + with_distribution)) + if use_per_core_batch_size: + batch_size //= with_distribution.num_replicas_in_sync + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + 'x': np.array(x_predict, dtype=np.float32), + } + else: + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. + train_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper( + train_dataset, batch_size, with_distribution, repeat=training_epochs) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': training_epochs, + 'shuffle': False, + 'steps_per_epoch': len(x_train) // global_batch_size, + } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + + predict_batch_size = len(x_predict) + if use_per_core_batch_size: + predict_batch_size //= with_distribution.num_replicas_in_sync + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, with_distribution) + predict_inputs = { + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + +strategies_minus_tpu = [ + combinations.default_strategy, + 
combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus] + +tpu_strategies = [ + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step] + + +def strategy_minus_tpu_combinations(): + return combinations.combine( + distribution=strategies_minus_tpu, + mode=['graph', 'eager']) + + +def tpu_strategy_combinations(): + return combinations.combine( + distribution=tpu_strategies, + mode=['graph']) + + +def all_strategy_combinations(): + return strategy_minus_tpu_combinations() + tpu_strategy_combinations() + + +def strategy_and_optimizer_combinations(): + return combinations.times( + all_strategy_combinations(), + combinations.combine(optimizer=[ + combinations.adagrad_optimizer_v1_fn, + combinations.adagrad_optimizer_keras_v2_fn, + combinations.adam_optimizer_v1_fn, + combinations.adam_optimizer_keras_v2_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_keras_v2_fn, + combinations.rmsprop_optimizer_v1_fn, + combinations.rmsprop_optimizer_keras_v2_fn + ])) + + +def strategy_and_input_combinations(): + return ( + combinations.times( + combinations.combine(distribution=strategies_minus_tpu), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]) + + combinations.combine(mode=['eager'], + use_numpy=[False], + use_validation_data=[False])) + + combinations.times( + combinations.combine(distribution=tpu_strategies), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]))) + + +def strategy_for_numpy_input_combinations(): + return combinations.combine( + distribution=strategies_minus_tpu + tpu_strategies, + mode=['graph']) + + +class TestDistributionStrategyWithNumpyArrays(test.TestCase, + parameterized.TestCase): + + 
@combinations.generate(strategy_for_numpy_input_combinations()) + def test_calling_model_with_numpy_arrays(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calling_model_with_nested_numpy_arrays(self, distribution): + with self.cached_session(): + model = multi_input_output_model() + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) + input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32) + inputs = [input_a_np, input_b_np] + + output_d_np = np.asarray(np.random.random((64, 7)), dtype=np.float32) + output_e_np = np.asarray(np.random.random((64, 7)), dtype=np.float32) + targets = [output_d_np, output_e_np] + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. 
+ model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + @combinations.generate(combinations.combine( + distribution=strategies_minus_tpu, mode=['graph'])) + def test_numpy_with_sample_weights(self, distribution): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) + + model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, + steps_per_epoch=2, verbose=1) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_flatten_predict_outputs(self, distribution): + with self.cached_session(): + model = multi_input_output_model() + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # We take 6 input samples with each input having a dimension of 3 or 5. + input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32) + input_b_np = np.asarray(np.random.random((6, 5)), dtype=np.float32) + inputs = [input_a_np, input_b_np] + + outs = model.predict(inputs, steps=1) + # `predict` a list that is equal in length to the number of model outputs. + # In this test our model has two outputs and each element of `outs` + # corresponds to all the samples of one of the model outputs. + self.assertLen(outs, 2) + # Each of the output samples have a dimension of 7. We should process all + # the available input samples(6). 
+ self.assertAllEqual([6, 7], outs[0].shape) + self.assertAllEqual([6, 7], outs[1].shape) + + +class TestDistributionStrategyWithDatasets(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_calling_model_on_same_dataset(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + dataset = get_dataset(distribution) + + # Call fit with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.predict(get_predict_dataset(distribution), steps=2) + + @combinations.generate(all_strategy_combinations()) + def test_model_interleaved_eval_same_as_direct_eval(self, distribution): + with self.cached_session(): + user_controlled_model = get_model() + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) + + interleaved_model = get_model() + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) + + dataset = get_dataset(distribution) + + # Call fit with validation interleaved + interleaved_output = interleaved_model.fit( + dataset, epochs=2, steps_per_epoch=2, verbose=1, + validation_data=dataset, validation_steps=2, shuffle=False) + + # Manually control the validation running after each epoch. 
+ user_controlled_output = [] + for _ in range(2): + user_controlled_model.fit( + dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False) + user_controlled_output.append( + user_controlled_model.evaluate(dataset, steps=2)) + + self.assertEqual(interleaved_output.history['val_loss'], + [x[0] for x in user_controlled_output]) + self.assertEqual(interleaved_output.history['val_mean_absolute_error'], + [x[1] for x in user_controlled_output]) + self.assertEqual(interleaved_output.history['val_categorical_accuracy'], + [x[2] for x in user_controlled_output]) + + # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work + # as clone_model's input_tensors argument only seems to accept list and not + # tuples or dict. + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): + with self.cached_session(): + model = multi_input_output_model() + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 5)) + output_d_np = np.random.random((10, 7)) + output_e_np = np.random.random((10, 7)) + + # Test with tuples + dataset_tuple = dataset_ops.Dataset.from_tensor_slices(( + (input_a_np, input_b_np), (output_d_np, output_e_np))) + dataset_tuple = dataset_tuple.repeat(100) + dataset_tuple = dataset_tuple.batch(10) + + model.fit(dataset_tuple, epochs=1, steps_per_epoch=2, verbose=1) + + # Test with dict + dataset_dict = dataset_ops.Dataset.from_tensor_slices(( + {'input_a': input_a_np, 'input_b': input_b_np}, + (output_d_np, output_e_np))) + dataset_dict = dataset_dict.repeat(100) + dataset_dict = 
dataset_dict.batch(10) + + model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) + + @combinations.generate(all_strategy_combinations()) + def test_fit_eval_and_predict_methods_on_dataset(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + dataset = get_dataset(distribution) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(get_predict_dataset(distribution), steps=2) + + @combinations.generate(strategy_and_optimizer_combinations()) + def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): + with self.cached_session(): + model = get_model() + + loss = 'mse' + model.compile(optimizer(), loss, distribute=distribution) + + dataset = get_dataset(distribution) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(get_predict_dataset(distribution), steps=2) + + @combinations.generate(strategy_minus_tpu_combinations()) + def test_dataset_with_sample_weights(self, distribution): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + inputs = np.zeros((10, 3), np.float32) + targets = np.zeros((10, 4), np.float32) + sample_weights = np.ones((10), np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, + sample_weights)) + dataset = dataset.repeat() + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + 
combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. + def DISABLED_test_dataset_wrong_input_shape(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. 
+ def DISABLED_test_dataset_no_batch_input_validation(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[combinations.tpu_strategy_one_step], + mode=['graph'])) + def test_dataset_input_shape_fully_defined(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) + # Input shapes are not fully known. Batch dimension is unknown as we are + # not using the drop_remainder argument. + dataset = dataset.repeat(100).batch(10) + + with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_learning_phase_value(self, distribution): + # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare + # meaningful values. Currently we don't pass the learning phase if the + # Lambda layer uses the learning phase. 
+ with self.cached_session(): + x = keras.layers.Input(shape=(1,), name='input') + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + initial_weights = model.get_weights() + + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + batch_size = 8 + if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): + # CoreMirroredStrategy uses global batch size. + batch_size = 8 * distribution.num_replicas_in_sync + + inputs = np.ones((10, 1), dtype=np.float32) + targets = np.ones((10, 1), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat().batch(batch_size) + hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) + self.assertAlmostEqual(hist.history['acc'][0], 0, 0) + + model.set_weights(initial_weights) + # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185. 
+ # evaluate_output = model.evaluate(dataset, steps=20) + # self.assertAlmostEqual(evaluate_output[1], 1, 0) + + inputs = np.ones((10, 1), dtype=np.float32) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + + predict_dataset = predict_dataset.repeat().batch(batch_size) + output = model.predict(predict_dataset, steps=10) + # `predict` runs for 10 steps + ref_output = np.ones((160, 1), dtype=np.float32) + self.assertArrayNear(output, ref_output, 1e-1) + + @combinations.generate(strategy_minus_tpu_combinations()) + def testOptimizerWithCallbacks(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent_keras.SGD(0.01) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) + + def schedule(_): + return 0.001 + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + grouped_models = distribution.unwrap( + distributed_training_utils.get_distributed_model( + model, ModeKeys.TRAIN)) + with distribution.scope(): + for m in grouped_models: + self.assertAllClose(0.001, keras.backend.get_value( + m.optimizer.lr), atol=1e-05, rtol=1e-05) + + +class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_unsupported_features(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + dataset = get_dataset(distribution) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a 
dataset or a ' + 'dataset iterator.+'): + model.fit(dataset, + epochs=1, steps_per_epoch=2, verbose=0, + validation_split=0.5, validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + ValueError, '`sample_weight` argument is not supported when input ' + '`x` is a dataset or a dataset iterator.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument for dataset with + # infinite cardinality. + dataset = dataset.repeat() + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): + model.predict(dataset, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + dataset = get_dataset(distribution) + + def schedule(_): + return 0.001 + with self.assertRaisesRegexp(ValueError, + 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp(ValueError, + 
'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + + +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): + + # TODO(priyag): Enable all strategies for this test. Currently it does not + # work for TPU due to some invalid datatype. + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_masking(self, distribution): + with self.cached_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=distribution) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class TestDistributionStrategyWithNormalizationLayer( + test.TestCase, parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_batchnorm_correctness(self, distribution): + with self.cached_session(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=distribution) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + x = x.astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = 
dataset.repeat(100) + dataset = batch_wrapper(dataset, 32, distribution) + + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) + predict_dataset = predict_dataset.repeat(100) + predict_dataset = batch_wrapper(predict_dataset, 32, distribution) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(predict_dataset, steps=2) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class TestDistributionStrategyCorrectness(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_metric_correctness(self, distribution): + with self.cached_session(): + keras.backend.set_image_data_format('channels_last') + num_samples = 10000 + + x_train = np.random.randint(0, 2, num_samples) + x_train = np.reshape(x_train, (num_samples, 1)) + y_train = x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + # Create identity model. 
+ model = keras.Sequential() + model.add( + keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones')) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + metrics=[keras.metrics.BinaryAccuracy()], + distribute=distribution) + + batch_size = 64 + if not distributed_training_utils.global_batch_size_supported( + distribution): + batch_size //= distribution.num_replicas_in_sync + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + train_dataset = batch_wrapper(train_dataset, batch_size, distribution) + + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) + + @combinations.generate(all_strategy_combinations()) + def test_eval_metrics_correctness(self, distribution): + with self.cached_session(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + distribute=distribution) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) + + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) 
+ + @combinations.generate(strategy_and_input_combinations()) + def test_correctness(self, distribution, use_numpy, use_validation_data): + with self.cached_session(): + default_tolerance = 1e-5 + tol_table = {} + + if isinstance(distribution, ( + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + distribute_lib._DefaultDistributionStrategy)): # pylint: disable=protected-access + # TODO(b/119257215): Weights are not exactly the same, so use larger + # tolerance for now. Predict should be related to weights. + tol_table = { + 'weights_1': 1e-4, + 'weights_2': 1e-4, + 'predict_result_1': 1e-4, + } + + keras.backend.set_image_data_format('channels_last') + np.random.seed(_RANDOM_SEED) + random_seed.set_random_seed(_RANDOM_SEED) + + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] + + # The model is built once and the initial weights are saved. + # This is used to initialize the model for both the distribution and + # non-distribution run. In addition, we add few non-linear layers to make + # it non-trivial. + def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model + + model = _create_model() + initial_weights = model.get_weights() + del model # avoid accident usage. + + def fit_eval_and_predict(with_distribution=None): + model = _create_model() + # We have initialized the model to the same weight for the distribution + # and non-distribution run. 
+ model.set_weights(initial_weights) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse'], + distribute=with_distribution) + + training_inputs, eval_inputs, predict_inputs = ( + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict)) + + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + result['predict_result_1'] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + results_with_ds = fit_eval_and_predict(with_distribution=distribution) + results_without_ds = fit_eval_and_predict(with_distribution=None) + + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. 
+ continue + + tolerance = tol_table.get(key, default_tolerance) + + self.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_correctness_test_base.py b/tensorflow/contrib/distribute/python/keras_correctness_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb7a18c40484ce01a5acfd6b191de464cfd9840 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_correctness_test_base.py @@ -0,0 +1,487 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness tests for tf.keras using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import random_seed +from tensorflow.python.keras.engine import distributed_training_utils + +_RANDOM_SEED = 1337 +_EVAL_STEPS = 20 +_GLOBAL_BATCH_SIZE = 64 + +# Note: Please make sure the tests in this file are also covered in +# keras_backward_compat_test for features that are supported with both APIs. 
+ + +all_strategies = [ + combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus, + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step, +] + + +def eager_mode_test_configuration(): + return combinations.combine(mode='eager', + use_numpy=False, + use_validation_data=False) + + +def graph_mode_test_configuration(): + return combinations.combine(mode='graph', + use_numpy=[True, False], + use_validation_data=[True, False]) + + +def all_strategy_and_input_config_combinations(): + return ( + combinations.times( + combinations.combine(distribution=all_strategies), + eager_mode_test_configuration() + graph_mode_test_configuration())) + + +def strategies_for_embedding_models(): + """Returns distribution strategies to test for embedding models. + + Since embedding models take longer to train, we disregard OneDeviceStrategy + and DefaultStrategy in order to prevent testing timeouts. 
+ """ + + return [s for s in all_strategies if s.required_tpu or s.required_gpus] + + +def test_combinations_for_embedding_model(): + return ( + combinations.times( + combinations.combine(distribution= + strategies_for_embedding_models()), + (graph_mode_test_configuration() + + eager_mode_test_configuration()))) + + +def test_combinations_with_tpu_strategies(): + tpu_strategies = [combinations.tpu_strategy, + combinations.tpu_strategy_one_step] + + return ( + combinations.times( + combinations.combine(distribution=tpu_strategies), + graph_mode_test_configuration())) + + +class MaybeDistributionScope(object): + """Provides a context allowing no distribution strategy.""" + + def __init__(self, distribution): + self._distribution = distribution + self._scope = None + + def __enter__(self): + if self._distribution: + self._scope = self._distribution.scope() + self._scope.__enter__() + + def __exit__(self, exc_type, value, traceback): + if self._distribution: + self._scope.__exit__(exc_type, value, traceback) + self._scope = None + + +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. + if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def get_batch_size(global_batch_size, distribution): + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. 
+ use_per_core_batch_size = ( + distribution and + not distributed_training_utils.global_batch_size_supported(distribution)) + if use_per_core_batch_size: + batch_size //= distribution.num_replicas_in_sync + return batch_size + + +def get_data_size(data): + """Gets the size of data in list, tuple, dict, or a numpy array.""" + assert isinstance(data, (np.ndarray, list, dict, tuple)) + + if isinstance(data, np.ndarray): + return len(data) + + if isinstance(data, (list, tuple)): + return len(data[0]) + + return len(six.next(six.itervalues(data))) + + +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 + global_batch_size = _GLOBAL_BATCH_SIZE + batch_size = get_batch_size(global_batch_size, with_distribution) + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + 'x': x_predict + } + else: + training_data_size = get_data_size(x_train) + if training_data_size < _GLOBAL_BATCH_SIZE * _EVAL_STEPS: + # Currently, we cannot detect the size of a dataset. So, the eval steps is + # hard coded. + raise ValueError('x_train must have at least ' + '_GLOBAL_BATCH_SIZE * _EVAL_STEPS samples') + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. 
+ train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + x = batch_wrapper(train_dataset, batch_size, with_distribution, + repeat=training_epochs) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': training_epochs, + 'shuffle': False, + 'steps_per_epoch': training_data_size // global_batch_size, + } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': _EVAL_STEPS, + } + + predict_batch_size = get_batch_size(get_data_size(x_predict), + with_distribution) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, predict_batch_size, + with_distribution) + predict_inputs = { + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + +def fit_eval_and_predict(initial_weights, input_fn, model_fn, + distribution=None, is_stateful_model=False): + """Generates results for fit/predict/evaluate for given model.""" + model = model_fn(initial_weights=initial_weights, distribution=distribution) + training_inputs, eval_inputs, predict_inputs = input_fn(distribution) + + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + + if predict_inputs is not None: + # Check correctness of the result of predict() invoked + # multiple times -- as for stateful models, result of + # predict may differ for each batch. 
+ predict_length = 1 + if is_stateful_model: + predict_length = 3 + for i in range(predict_length): + result_key = 'predict_result_{}'.format(i) + result[result_key] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + +def compare_results(results_with_ds, results_without_ds, distribution, + testcase): + """Compares results of model compiled with/without distribution strategy.""" + + default_tolerance = 1e-5 + relaxed_tolerance = 1e-4 + + def _get_compare_result_tolerance(key): + """Returns tolerance to compare results.""" + # TODO(b/119257215): For MirroredStrategy, weights are not exactly the same, + # so use larger tolerance for now. Predict should be related to weights. + if (isinstance(distribution, ( + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + distribute_lib._DefaultDistributionStrategy)) and # pylint: disable=protected-access + key.startswith(('weights_1', 'weights_2', 'predict_result'))): + return relaxed_tolerance + + return default_tolerance + + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. 
+ continue + + tolerance = _get_compare_result_tolerance(key) + testcase.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) + + +def should_skip_tpu_with_eager(distribution): + return (context.executing_eagerly() and + isinstance(distribution, tpu_strategy.TPUStrategy)) + + +class LearningRateBatchScheduler(keras.callbacks.Callback): + """Scheduler that dynamically sets the learning rate of model.""" + + def __init__(self, update_freq=None): + self._update_freq = update_freq + + def on_batch_begin(self, batch, logs=None): + if self._update_freq and batch % self._update_freq != 0: + return + + # To avoid divergence, limit the value range. + lr = 0.001 * (batch % 10) + keras.backend.set_value(self.model.optimizer.lr, lr) + + +class TestDistributionStrategyCorrectnessBase(test.TestCase, + parameterized.TestCase): + """Model agnostic testing infra to test correctness of Keras models.""" + + def set_up_test_config(self, use_numpy=False, + use_validation_data=False, + with_batch_norm=False): + self.use_numpy = use_numpy + self.use_validation_data = use_validation_data + self.with_batch_norm = with_batch_norm + + keras.backend.set_image_data_format('channels_last') + np.random.seed(_RANDOM_SEED) + random_seed.set_random_seed(_RANDOM_SEED) + + def get_data(self): + num_samples = 10000 + x_train = np.random.randint(0, 2, num_samples) + x_train = np.reshape(x_train, (num_samples, 1)) + y_train = x_train + return (x_train.astype('float32'), y_train.astype('float32'), None) + + def get_model(self, distribution=None): + raise NotImplementedError + + def skip_unsupported_test_configuration(self, distribution): + if should_skip_tpu_with_eager(distribution): + self.skipTest('TPUStrategy does not support eager mode now.') + + if context.executing_eagerly() and self.use_numpy: + self.skipTest('Numpy as inputs is not supported with strategy in eager.') + + if context.executing_eagerly() and 
self.use_validation_data: + self.skipTest('TODO(hongjunchoi): Add test logic for using validation ' + 'data for eager execution.') + return + + def run_correctness_test(self, + distribution, + use_numpy, + use_validation_data, + with_batch_norm=False, + is_stateful_model=False): + with self.cached_session(): + self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm) + self.skip_unsupported_test_configuration(distribution) + + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + x_train, y_train, x_predict = self.get_data() + + # The model is built once and the initial weights are saved. + # This is used to initialize the model for both the distribution and + # non-distribution run. + model = self.get_model() + initial_weights = model.get_weights() + + def input_fn(dist): + return get_correctness_test_inputs( + use_numpy, use_validation_data, dist, x_train, y_train, x_predict) + + results_with_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=distribution, is_stateful_model=is_stateful_model) + results_without_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=None, is_stateful_model=is_stateful_model) + + # First, special case, for multi-replica distributed training, batch norm + # is not aggregated globally. So it is expected to have different weights. 
+ if (self.with_batch_norm and + distribution.num_replicas_in_sync > 1): + with self.assertRaises(AssertionError): + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + else: + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + + def run_dynamic_lr_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + x_train, y_train, _ = self.get_data() + model = self.get_model() + initial_weights = model.get_weights() + update_freq = None + + if (isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # For TPUStrategy with steps_per_run > 1, the callback is not invoked + # every step. So, to compare the CPU/TPU, we let the CPU to behave the + # same as TPU. + update_freq = distribution.extended.steps_per_run + + def input_fn(dist): + """Generates training test given test configuration.""" + training_epochs = 2 + global_batch_size = 64 + batch_size = get_batch_size(global_batch_size, dist) + + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + 'callbacks': [LearningRateBatchScheduler(update_freq)], + 'validation_data': (x_train, y_train) + } + # In this test case, we do not care eval and predict. 
+ eval_inputs, predict_inputs = None, None + return training_inputs, eval_inputs, predict_inputs + + results_with_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=distribution) + results_without_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=None) + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + + +class TestDistributionStrategyEmbeddingModelCorrectnessBase( + TestDistributionStrategyCorrectnessBase): + """Base class to test correctness of Keras models with embedding layers.""" + + def get_data(self, + count=(_GLOBAL_BATCH_SIZE * _EVAL_STEPS), + min_words=5, + max_words=10, + max_word_id=19, + num_classes=2): + distribution = [] + for _ in range(num_classes): + dist = np.abs(np.random.randn(max_word_id)) + dist /= np.sum(dist) + distribution.append(dist) + + features = [] + labels = [] + for _ in range(count): + label = np.random.randint(0, num_classes, size=1)[0] + num_words = np.random.randint(min_words, max_words, size=1)[0] + word_ids = np.random.choice( + max_word_id, size=num_words, replace=True, p=distribution[label]) + word_ids = word_ids + labels.append(label) + features.append(word_ids) + + features = keras.preprocessing.sequence.pad_sequences( + features, maxlen=max_words) + x_train = np.asarray(features, dtype=np.float32) + y_train = np.asarray(labels, dtype=np.int32).reshape((count, 1)) + x_predict = x_train[:_GLOBAL_BATCH_SIZE] + return x_train, y_train, x_predict + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..61202e30c4f33892d2675080fae07cc4d7102337 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py @@ -0,0 +1,173 @@ +# Copyright 2019 The TensorFlow 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Correctness tests for tf.keras DNN model using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.training import gradient_descent + + +def all_strategy_combinations_with_eager_and_graph_modes(): + return combinations.combine(distribution=keras_correctness_test_base. + all_strategies, + mode=['graph', 'eager']) + + +def all_strategy_combinations_with_graph_mode(): + return combinations.combine(distribution=keras_correctness_test_base. + all_strategies, mode=['graph']) + + +class TestDistributionStrategyDnnCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + # We add few non-linear layers to make it non-trivial. 
+ model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense( + 10, activation='relu', + kernel_regularizer=keras.regularizers.l2(1e-4))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse']) + return model + + def get_data(self): + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32) + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_dnn_correctness(self, distribution, use_numpy, use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(all_strategy_combinations_with_graph_mode()) + def test_dnn_with_dynamic_learning_rate(self, distribution): + self.run_dynamic_lr_test(distribution) + + +class TestDistributionStrategyDnnMetricCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, distribution=None): + with distribution.scope(): + model = keras.Sequential() + model.add(keras.layers.Dense(1, + input_shape=(1,), + kernel_initializer='ones')) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + metrics=[keras.metrics.BinaryAccuracy()]) + return model + + def run_metric_correctness_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + x_train, y_train, _ = 
self.get_data() + model = self.get_model(distribution=distribution) + + batch_size = 64 + batch_size = (keras_correctness_test_base. + get_batch_size(batch_size, distribution)) + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + train_dataset = (keras_correctness_test_base. + batch_wrapper(train_dataset, batch_size, distribution)) + + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) + + @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) + def test_simple_dnn_metric_correctness(self, distribution): + self.run_metric_correctness_test(distribution) + + +class TestDistributionStrategyDnnMetricEvalCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, distribution=None): + with distribution.scope(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001)) + return model + + def run_eval_metrics_correctness_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + model = self.get_model(distribution=distribution) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = (keras_correctness_test_base. + batch_wrapper(dataset, 4, distribution)) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) 
+ + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = (keras_correctness_test_base. + batch_wrapper(dataset, 4, distribution)) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) + def test_identity_model_metric_eval_correctness(self, distribution): + self.run_eval_metrics_correctness_test(distribution) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e881bb70ecc428e3f972cde5f19c1b61b1dc0f0b --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py @@ -0,0 +1,150 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness test for tf.keras Embedding models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +class DistributionStrategyEmbeddingModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + if self.use_distributed_dense: + word_embed = keras.layers.TimeDistributed(keras.layers.Dense(4))( + word_embed) + avg = keras.layers.GlobalAveragePooling1D()(word_embed) + preds = keras.layers.Dense(2, activation='softmax')(avg) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_embedding_model_correctness(self, distribution, use_numpy, + use_validation_data): + + self.use_distributed_dense = False + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(keras_correctness_test_base. 
+ test_combinations_for_embedding_model()) + def test_embedding_time_distributed_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.use_distributed_dense = True + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +class DistributionStrategySiameseEmbeddingModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids_a = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words_a') + word_ids_b = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words_b') + + def submodel(embedding, word_ids): + word_embed = embedding(word_ids) + rep = keras.layers.GlobalAveragePooling1D()(word_embed) + return keras.Model(inputs=[word_ids], outputs=[rep]) + + word_embed = keras.layers.Embedding( + input_dim=20, + output_dim=10, + input_length=max_words, + embeddings_initializer=keras.initializers.RandomUniform(0, 1)) + + a_rep = submodel(word_embed, word_ids_a).outputs[0] + b_rep = submodel(word_embed, word_ids_b).outputs[0] + sim = keras.layers.Dot(axes=1, normalize=True)([a_rep, b_rep]) + + model = keras.Model(inputs=[word_ids_a, word_ids_b], outputs=[sim]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='mse', + metrics=['mse']) + return model + + def get_data(self, + count=(keras_correctness_test_base._GLOBAL_BATCH_SIZE * + keras_correctness_test_base._EVAL_STEPS), + min_words=5, + max_words=10, + max_word_id=19, + num_classes=2): + features_a, labels_a, _ = (super( + DistributionStrategySiameseEmbeddingModelCorrectnessTest, self). 
+ get_data(count, min_words, max_words, + max_word_id, num_classes)) + + features_b, labels_b, _ = (super( + DistributionStrategySiameseEmbeddingModelCorrectnessTest, self). + get_data(count, min_words, max_words, + max_word_id, num_classes)) + + y_train = np.zeros((count, 1), dtype=np.float32) + y_train[labels_a == labels_b] = 1.0 + y_train[labels_a != labels_b] = -1.0 + # TODO(b/123360757): Add tests for using list as inputs for multi-input + # models. + x_train = { + 'words_a': features_a, + 'words_b': features_b, + } + x_predict = x_train + + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_siamese_embedding_model_correctness(self, distribution, use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2961456b2eede9570ce29f7a8900834f2ccfb7 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py @@ -0,0 +1,93 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness tests for tf.keras CNN models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.keras.optimizer_v2 import gradient_descent + + +class DistributionStrategyCnnCorrectnessTest( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + image = keras.layers.Input(shape=(28, 28, 3), name='image') + c1 = keras.layers.Conv2D( + name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4), + kernel_regularizer=keras.regularizers.l2(1e-4))( + image) + if self.with_batch_norm: + c1 = keras.layers.BatchNormalization(name='bn1')(c1) + c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1) + logits = keras.layers.Dense( + 10, activation='softmax', name='pred')( + keras.layers.Flatten()(c1)) + model = keras.Model(inputs=[image], outputs=[logits]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.SGD( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + + return model + + def get_data(self, + count=keras_correctness_test_base._GLOBAL_BATCH_SIZE + * keras_correctness_test_base._EVAL_STEPS, + shape=(28, 28, 3), + num_classes=10): + centers = np.random.randn(num_classes, *shape) + + features = [] + labels = [] + for _ in range(count): + label = np.random.randint(0, num_classes, size=1)[0] + offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape)) + offset = 
offset.reshape(shape) + labels.append(label) + features.append(centers[label] + offset) + + x_train = np.asarray(features, dtype=np.float32) + y_train = np.asarray(labels, dtype=np.float32).reshape((count, 1)) + x_predict = x_train + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_cnn_correctness(self, distribution, use_numpy, use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + with_batch_norm=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed2dfa206cdf4be24a88b1d54090487c1873399 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py @@ -0,0 +1,65 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness tests for tf.keras LSTM model using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +class DistributionStrategyLstmModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + lstm_embed = keras.layers.LSTM(units=4, + return_sequences=False)(word_embed) + + preds = keras.layers.Dense(2, activation='softmax')(lstm_embed) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(keras_correctness_test_base. 
+ test_combinations_for_embedding_model()) + def test_lstm_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 6dfd85bcc4f3784e2744fd876a7190cc9581d96a..c93d7afa7ceef2c9c272e91997e2871655cea079 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -18,24 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import shutil -import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib.distribute.python import combinations -from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context -from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import training -from tensorflow.python.estimator.canned import dnn_linear_combined -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column_lib as feature_column +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -44,103 +33,7 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from 
tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache - - -class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def dataset_input_fn(self, x, y, batch_size): - - def input_fn(): - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(1).batch(batch_size) - return dataset - - return input_fn - - @combinations.generate( - combinations.combine( - mode=['graph'], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus - ], - use_train_and_evaluate=[True, False])) - def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): - label_dimension = 2 - input_dimension = label_dimension - batch_size = 10 - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - train_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - eval_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, batch_size=batch_size, shuffle=False) - - linear_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - dnn_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - feature_columns = linear_feature_columns + dnn_feature_columns - session_config = config_pb2.ConfigProto( - log_device_placement=True, allow_soft_placement=True) - estimator = dnn_linear_combined.DNNLinearCombinedRegressor( - linear_feature_columns=linear_feature_columns, - 
dnn_hidden_units=(2, 2), - dnn_feature_columns=dnn_feature_columns, - label_dimension=label_dimension, - model_dir=self._model_dir, - dnn_optimizer=adam.Adam(0.001), - linear_optimizer=adam.Adam(0.001), - config=run_config.RunConfig( - train_distribute=distribution, - eval_distribute=distribution, - session_config=session_config)) - - num_steps = 2 - if use_train_and_evaluate: - scores, _ = training.train_and_evaluate( - estimator, training.TrainSpec(train_input_fn, max_steps=num_steps), - training.EvalSpec(eval_input_fn)) - else: - estimator.train(train_input_fn, steps=num_steps) - scores = estimator.evaluate(eval_input_fn) - - self.assertIn('loss', six.iterkeys(scores)) - - predictions = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in estimator.predict(predict_input_fn) - ]) - self.assertAllEqual((batch_size, label_dimension), predictions.shape) - - feature_spec = feature_column.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) def get_model(): @@ -152,113 +45,80 @@ def get_model(): class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + @combinations.generate( + combinations.combine( + distribution=[ + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus, + combinations.parameter_server_strategy_with_two_gpus, + ], + mode=['graph', 'eager'])) def testKerasOptimizerWithUnequalInput(self, distribution): - def create_fn(): + with 
distribution.scope(): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) - # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. - loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) - train_op = optimizer.minimize(loss, var_list=[var]) - m = optimizer.get_slot(var, 'm') - v = optimizer.get_slot(var, 'v') - return (var, m, v, train_op, optimizer.iterations) + all_vars = [] - devices = ['/device:GPU:0', '/device:CPU:0'] - with distribution.scope(): - (var, m, v, op, counter) = distribution.call_for_each_replica(create_fn) + def model_fn(): + + def loss_fn(): + replica_id = _replica_id() + return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * 0.5 * var + + train_op = optimizer.minimize(loss_fn, var_list=[var]) + + return train_op, optimizer + + def train_fn(): + train_op, optimizer = distribution.extended.call_for_each_replica( + model_fn) + if not all_vars: + all_vars.append(var) + all_vars.append(optimizer.get_slot(var, 'm')) + all_vars.append(optimizer.get_slot(var, 'v')) + return distribution.group(train_op) + + if not context.executing_eagerly(): + with self.cached_session() as sess: + train_fn = sess.make_callable(train_fn()) self.evaluate(variables.global_variables_initializer()) - var_val = [2.0, 2.0, 2.0] - self.assertAllClose( - var_val, - self.evaluate( - [distribution.read_var(var), - var.get(devices[0]), - var.get(devices[1])])) - self.assertAllClose([0, 0, 0], - self.evaluate([ - distribution.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) - train_op = distribution.unwrap(op) - self.evaluate(train_op) - # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 - m_val = [1.2, 1.2, 1.2] - # assert slot variables in both replicas are the same. 
- self.assertAllClose( - m_val, - self.evaluate( - [distribution.read_var(m), - m.get(devices[0]), - m.get(devices[1])])) - # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 - v_val = [1.8, 1.8, 1.8] - self.assertAllClose( - v_val, - self.evaluate( - [distribution.read_var(v), - v.get(devices[0]), - v.get(devices[1])])) + # first step. + train_fn() # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) # = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8 - var_val = [1.99, 1.99, 1.99] - self.assertAllClose( - var_val, - self.evaluate( - [distribution.read_var(var), - var.get(devices[0]), - var.get(devices[1])])) - self.assertAllClose([1, 1, 1], - self.evaluate([ - distribution.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) + self.assertAllClose(1.99, self.evaluate(all_vars[0])) + # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 + self.assertAllClose(1.2, self.evaluate(all_vars[1])) + # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 + self.assertAllClose(1.8, self.evaluate(all_vars[2])) - self.evaluate(train_op) + # second step. 
+ train_fn() + # var(1) = var(0) - lr * 2 = 1.98 + self.assertAllClose(1.98, self.evaluate(all_vars[0])) # m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5 - m_val = [1.44, 1.44, 1.44] - self.assertAllClose( - m_val, - self.evaluate( - [distribution.read_var(m), - m.get(devices[0]), - m.get(devices[1])])) + self.assertAllClose(1.44, self.evaluate(all_vars[1])) # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 - v_val = [2.16, 2.16, 2.16] - self.assertAllClose( - v_val, - self.evaluate( - [distribution.read_var(v), - v.get(devices[0]), - v.get(devices[1])])) - self.assertAllClose([2, 2, 2], - self.evaluate([ - distribution.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) + self.assertAllClose(2.16, self.evaluate(all_vars[2])) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + @combinations.generate( + combinations.combine( + distribution=[ + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.parameter_server_strategy_with_two_gpus, + ], + mode=['graph', 'eager'])) def testOptimizerWithKerasModelAndNumpyArrays(self, distribution): with self.cached_session(): - model = get_model() - optimizer = gradient_descent.SGD(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.SGD(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) inputs = np.zeros((64, 3), dtype=np.float32) targets = np.zeros((64, 4), dtype=np.float32) diff --git a/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py new file mode 100644 index 
0000000000000000000000000000000000000000..f5faf6c36b880a72bafc8d082cff2816f3b11a76 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py @@ -0,0 +1,99 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for stateful tf.keras LSTM models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +def strategies_for_stateful_embedding_model(): + """Returns TPUStrategy with single core device assignment.""" + + return [combinations.tpu_strategy_one_core, + combinations.tpu_strategy_one_step_one_core] + + +def test_combinations_for_stateful_embedding_model(): + return ( + combinations.combine( + distribution=strategies_for_stateful_embedding_model(), + mode='graph', + use_numpy=False, + use_validation_data=False + )) + + +class DistributionStrategyStatefulLstmModelCorrectnessTest( + keras_correctness_test_base. 
+ TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE + + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), + batch_size=batch_size, + dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + lstm_embed = keras.layers.LSTM(units=4, + return_sequences=False, + stateful=True)(word_embed) + + preds = keras.layers.Dense(2, activation='softmax')(lstm_embed) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(test_combinations_for_stateful_embedding_model()) + def test_stateful_lstm_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + is_stateful_model=True) + + @combinations.generate(keras_correctness_test_base. + test_combinations_with_tpu_strategies()) + def test_incorrectly_use_multiple_cores_for_stateful_lstm_model( + self, distribution, use_numpy, use_validation_data): + with self.assertRaisesRegexp(ValueError, + 'Single core must be used for computation ' + 'on stateful models. 
Consider adding ' + '`device_assignment` parameter to ' + 'TPUStrategy'): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + is_stateful_model=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 683cc89bfbae9c877ea6794d311ffc00c96c6937..77e241974f7c4c27382ab548a202891fdbbc6ba0 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -25,18 +25,17 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.python import keras +from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import values from tensorflow.python.eager import test from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -48,6 +47,9 @@ _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +# Note: Please make sure the tests in this file are also covered in +# keras_backward_compat_test for features that are supported 
with both APIs. + # TODO(anjalisridhar): Add a decorator that will allow us to run these tests as # part of the tf.keras unit tests suite. @@ -68,6 +70,32 @@ def simple_functional_model(): return model +def simple_subclassed_model(num_labels=_NUM_CLASS): + + class _SimpleMLP(keras.Model): + + def __init__(self, num_labels): + super(_SimpleMLP, self).__init__() + self.dense = keras.layers.Dense(num_labels) + + def call(self, inputs): + return self.dense(inputs) + + return _SimpleMLP(num_labels) + + +def simple_multi_inputs_multi_outputs_model(): + input_a = keras.layers.Input(shape=(16,), name='input_a') + input_b = keras.layers.Input(shape=(16,), name='input_b') + + merged = keras.layers.concatenate([input_a, input_b], name='merge') + output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) + output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged) + model = keras.models.Model( + inputs=[input_a, input_b], outputs=[output_c, output_d]) + return model + + def multi_inputs_multi_outputs_model(): input_a = keras.layers.Input(shape=(16,), name='input_a') input_b = keras.layers.Input(shape=(16,), name='input_b') @@ -200,6 +228,22 @@ def get_predict_dataset(distribution): return dataset +def convert_numpy_to_dataset_with_unknown_cardinality(inputs, + targets=None): + if targets is not None: + input_slices = (inputs, targets) + dummy_op = (lambda inp, target: True) + else: + input_slices = inputs + dummy_op = (lambda inp: True) + + original_dataset = (dataset_ops.Dataset.from_tensor_slices( + input_slices)) + ds_with_unknown_cardinality = (original_dataset.filter(dummy_op). 
+ batch(10, drop_remainder=True)) + return ds_with_unknown_cardinality + + def multi_input_output_model(): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(5,), name='input_b') @@ -214,90 +258,12 @@ def multi_input_output_model(): return model -def get_correctness_test_inputs(use_numpy, use_validation_data, - with_distribution, - x_train, y_train, x_predict): - """Generates the inputs for correctness check when enable Keras with DS.""" - training_epochs = 2 - global_batch_size = 64 - batch_size = global_batch_size - # TODO(b/118776054): Use global batch size for Keras/DS support. - use_per_core_batch_size = ( - with_distribution and - not distributed_training_utils.global_batch_size_supported( - with_distribution)) - if use_per_core_batch_size: - batch_size //= with_distribution.num_replicas_in_sync - - if use_numpy: - training_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - 'epochs': training_epochs, - 'shuffle': False, - } - - if use_validation_data: - eval_inputs = None - training_inputs['validation_data'] = (x_train, y_train) - else: - eval_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - } - predict_inputs = { - 'x': np.array(x_predict, dtype=np.float32), - } - else: - # For dataset inputs, we do not pass batch_size to - # keras.fit/evaluate/predict. The batch size is part of the dataset. 
- train_dataset = dataset_ops.Dataset.from_tensor_slices( - (x_train, y_train)) - x = batch_wrapper( - train_dataset, batch_size, with_distribution, repeat=training_epochs) - - training_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'epochs': training_epochs, - 'shuffle': False, - 'steps_per_epoch': len(x_train) // global_batch_size, - } - if use_validation_data: - eval_inputs = None # Remove the eval_inputs - eval_dataset = dataset_ops.Dataset.from_tensor_slices( - (x_train, y_train)) - x = batch_wrapper(eval_dataset, batch_size, with_distribution) - training_inputs['validation_data'] = x - training_inputs['validation_steps'] = 5 - else: - eval_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'steps': 20, - } - - predict_batch_size = len(x_predict) - if use_per_core_batch_size: - predict_batch_size //= with_distribution.num_replicas_in_sync - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) - predict_dataset = batch_wrapper(predict_dataset, - predict_batch_size, with_distribution) - predict_inputs = { - 'steps': 1, - 'x': predict_dataset, - } - - return training_inputs, eval_inputs, predict_inputs - - +# TODO(josh11b): Add combinations.one_device_strategy_gpu once it works with +# TestDistributionStrategyWithCallbacks.test_callbacks_in_predict. 
strategies_minus_tpu = [ combinations.default_strategy, combinations.one_device_strategy, + combinations.one_device_strategy_gpu, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, combinations.core_mirrored_strategy_with_gpu_and_cpu, @@ -309,53 +275,45 @@ tpu_strategies = [ def strategy_minus_tpu_combinations(): - return combinations.combine( - distribution=strategies_minus_tpu, - mode=['graph', 'eager']) + return combinations.combine(distribution=strategies_minus_tpu, + mode=['graph', 'eager']) def tpu_strategy_combinations(): - return combinations.combine( - distribution=tpu_strategies, - mode=['graph']) + return combinations.combine(distribution=tpu_strategies, + mode=['graph']) def all_strategy_combinations(): return strategy_minus_tpu_combinations() + tpu_strategy_combinations() -# TODO(priyag): Add v2 optimizers here. +def all_strategy_combinations_minus_default(): + strategy_minus_default_combinations = combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.one_device_strategy_gpu, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager']) + return strategy_minus_default_combinations + tpu_strategy_combinations() + + def strategy_and_optimizer_combinations(): return combinations.times( all_strategy_combinations(), - combinations.combine( - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn])) - - -def strategy_and_input_combinations(): - return ( - combinations.times( - combinations.combine(distribution=strategies_minus_tpu), - combinations.combine(mode=['graph'], - use_numpy=[True, False], - use_validation_data=[True, False]) - + combinations.combine(mode=['eager'], - 
use_numpy=[False], - use_validation_data=[False])) + - combinations.times( - combinations.combine(distribution=tpu_strategies), - combinations.combine(mode=['graph'], - use_numpy=[True, False], - use_validation_data=[True, False]))) - - -def strategy_for_numpy_input_combinations(): - return combinations.combine( - distribution=strategies_minus_tpu + tpu_strategies, - mode=['graph']) + combinations.combine(optimizer=[ + combinations.adagrad_optimizer_v1_fn, + combinations.adagrad_optimizer_keras_v2_fn, + combinations.adam_optimizer_v1_fn, + combinations.adam_optimizer_keras_v2_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_keras_v2_fn, + combinations.rmsprop_optimizer_v1_fn, + combinations.rmsprop_optimizer_keras_v2_fn + ])) class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @@ -375,7 +333,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_functional_with_distribution_strategy(self, distribution): @@ -403,7 +363,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_sequential_with_distribution_strategy(self, distribution): @@ -430,8 +392,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + 
combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() @@ -482,8 +444,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() @@ -509,16 +471,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_creating_var_with_numpy_arrays(self, distribution): - with self.cached_session(): - x = np.asarray(np.random.random((64, 3)), dtype=np.float32) - var_x = distributed_training_utils.get_var_for_numpy(distribution, x) - val = self.evaluate(var_x.value()) - # Verify that the numpy value is copied to the variable. 
- self.assertAllEqual(x, val) - - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_no_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies # that use per_core_batch_size @@ -549,7 +502,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=None, batch_size=None) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_with_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -595,7 +548,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=1, batch_size=None) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_no_steps_with_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -629,7 +582,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=None, batch_size=3) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_with_steps_with_batch_size(self, distribution): with self.cached_session(): @@ -646,45 +599,46 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=10, batch_size=13) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def 
test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - - inputs = np.zeros((64, 3), dtype=np.float32) - targets = np.zeros((64, 4), dtype=np.float32) - - # Call fit with validation data - model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, - validation_data=(inputs, targets)) - - # TODO(anjalisridhar): We need tests for when the batch size and steps are - # smaller and results in a 0 batch_size and steps value. - model.evaluate(inputs, targets) - # with steps - model.evaluate(inputs, targets, steps=2) - # with batch_size - model.evaluate(inputs, targets, batch_size=8) - - model.predict(inputs) - # with steps - model.predict(inputs, steps=2) - # with batch_size - model.predict(inputs, batch_size=8) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) + + # TODO(anjalisridhar): We need tests for when the batch size and steps + # are smaller and results in a 0 batch_size and steps value. 
+ model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32) @@ -714,26 +668,29 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, @combinations.generate(combinations.combine( distribution=strategies_minus_tpu, mode=['graph'])) def test_numpy_with_sample_weights(self, distribution): - model = get_model() - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - inputs = np.zeros((20, 3), np.float32) - targets = np.zeros((20, 4), np.float32) - sample_weights = np.ones((20), np.float32) + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) - model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, - steps_per_epoch=2, verbose=1) + model.fit(inputs, targets, 
sample_weight=sample_weights, epochs=1, + steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_flatten_predict_outputs(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # We take 6 input samples with each input having a dimension of 3 or 5. input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32) @@ -750,6 +707,61 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, self.assertAllEqual([6, 7], outs[0].shape) self.assertAllEqual([6, 7], outs[1].shape) + @combinations.generate(tpu_strategy_combinations()) + def test_predict_with_partial_batch(self, distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + + with distribution.scope(): + model_with_ds_strategy = get_model() + model_with_ds_strategy.compile(optimizer, loss) + + cpu_model = get_model() + cpu_model.compile(optimizer, loss) + + inputs = np.zeros((10, 3), dtype=np.float32) + + # As sample size is 10, we batch by 4 so that the last batch is + # a partial batch. Also `fit()` using numpy array as inputs without + # distribution strategy uses entire sample as a single batch. As so, + # we remove parameters `batch_size` and `steps`. 
+ cpu_model.set_weights(model_with_ds_strategy.get_weights()) + self.assertAllClose( + model_with_ds_strategy.predict(inputs, batch_size=4, steps=3), + cpu_model.predict(inputs), + atol=1e-5, rtol=1e-5) + + @combinations.generate(tpu_strategy_combinations()) + def test_predict_multi_output_model_with_partial_batch( + self, distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + + with distribution.scope(): + model_with_ds_strategy = simple_multi_inputs_multi_outputs_model() + model_with_ds_strategy.compile(optimizer, loss) + + cpu_model = simple_multi_inputs_multi_outputs_model() + cpu_model.compile(optimizer, loss) + + input_data, _ = get_multi_inputs_multi_outputs_data() + input_dict = { + 'input_a': input_data['input_a'], + 'input_b': input_data['input_b'], + } + + # As sample size is 200, we batch by 18 so that the last batch is + # a partial batch. Also `fit()` using numpy array as inputs without + # distribution strategy uses entire sample as a single batch. As so, + # we remove parameters `batch_size` and `steps`. 
+ cpu_model.set_weights(model_with_ds_strategy.get_weights()) + self.assertAllClose( + model_with_ds_strategy.predict(input_dict, batch_size=18, steps=12), + cpu_model.predict(input_dict), + atol=1e-4, rtol=1e-4) + class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): @@ -757,12 +769,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -776,20 +788,19 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): - user_controlled_model = get_model() - user_controlled_model.compile( - gradient_descent.GradientDescentOptimizer(0.001), - loss='mse', - metrics=['mae', keras.metrics.CategoricalAccuracy()], - distribute=distribution) - - interleaved_model = get_model() - interleaved_model.set_weights(user_controlled_model.get_weights()) - interleaved_model.compile( - gradient_descent.GradientDescentOptimizer(0.001), - loss='mse', - metrics=['mae', keras.metrics.CategoricalAccuracy()], - distribute=distribution) + with distribution.scope(): + user_controlled_model = get_model() + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()]) + + 
interleaved_model = get_model() + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()]) dataset = get_dataset(distribution) @@ -824,12 +835,13 @@ class TestDistributionStrategyWithDatasets(test.TestCase, mode=['graph', 'eager'])) def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 5)) @@ -854,14 +866,103 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) @combinations.generate(all_strategy_combinations()) - def test_fit_eval_and_predict_methods_on_dataset(self, distribution): + def test_fit_eval_and_predict_methods_on_dataset_without_steps( + self, distribution): with self.cached_session(): - model = get_model() + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((1000, 3), dtype=np.float32) + targets = np.zeros((1000, 4), dtype=np.float32) + # steps/steps_per_epoch are calculated when using numpy arrays as + # input data. 
+ fit_with_numpy = model.fit(inputs, targets, epochs=1, + batch_size=10).history + eval_with_numpy = model.evaluate(inputs, targets, batch_size=10) + predict_with_numpy = model.predict(inputs, batch_size=10) - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.batch(10, drop_remainder=True) + fit_with_ds = model.fit(dataset, epochs=1).history + eval_with_ds = model.evaluate(dataset) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + predict_dataset = predict_dataset.batch(10, drop_remainder=True) + predict_with_ds = model.predict(predict_dataset) + self.assertAllClose( + fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + + @combinations.generate(all_strategy_combinations()) + def test_on_dataset_with_unknown_cardinality_without_steps( + self, distribution): + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((1000, 3), dtype=np.float32) + targets = np.zeros((1000, 4), dtype=np.float32) + # steps/steps_per_epoch are calculated when using numpy arrays as + # input data. 
+ fit_with_numpy = model.fit(inputs, targets, epochs=1, + batch_size=10).history + fit_with_numpy_multiple_epochs = model.fit( + inputs, targets, epochs=2, batch_size=10).history + eval_with_numpy = model.evaluate(inputs, targets, batch_size=10) + predict_with_numpy = model.predict(inputs, batch_size=10) + + dataset = convert_numpy_to_dataset_with_unknown_cardinality( + inputs, targets) + predict_dataset = convert_numpy_to_dataset_with_unknown_cardinality( + inputs) + + self.assertEqual(keras.backend.get_value(cardinality.cardinality( + dataset)), cardinality.UNKNOWN) + self.assertEqual(keras.backend.get_value(cardinality.cardinality( + predict_dataset)), cardinality.UNKNOWN) + + eval_with_ds = model.evaluate(dataset) + predict_with_ds = model.predict(predict_dataset) + self.assertAllClose( + eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + + if (distributed_training_utils.is_tpu_strategy(distribution) and + distribution.extended.steps_per_run != 1): + with self.assertRaisesRegexp(ValueError, '`steps_per_epoch` ' + 'should be specified'): + fit_with_ds = model.fit(dataset, epochs=1) + else: + fit_with_ds = model.fit(dataset, + epochs=1).history + fit_with_ds_multiple_epochs = model.fit(dataset, + epochs=2).history + self.assertAllClose( + fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + fit_with_numpy_multiple_epochs, + fit_with_ds_multiple_epochs, atol=1e-4, rtol=1e-4) + + @combinations.generate(all_strategy_combinations()) + def test_fit_eval_and_predict_methods_on_dataset(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -872,10 +973,10 @@ class 
TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_and_optimizer_combinations()) def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): with self.cached_session(): - model = get_model() - - loss = 'mse' - model.compile(optimizer(), loss, distribute=distribution) + with distribution.scope(): + model = get_model() + loss = 'mse' + model.compile(optimizer(), loss) dataset = get_dataset(distribution) @@ -885,35 +986,39 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_minus_tpu_combinations()) def test_dataset_with_sample_weights(self, distribution): - model = get_model() - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) - - inputs = np.zeros((10, 3), np.float32) - targets = np.zeros((10, 4), np.float32) - sample_weights = np.ones((10), np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, - sample_weights)) - dataset = dataset.repeat() - dataset = dataset.batch(10) - - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) - model.evaluate(dataset, steps=2, verbose=1) - model.predict(dataset, steps=2) + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((10, 3), np.float32) + targets = np.zeros((10, 4), np.float32) + sample_weights = np.ones((10), np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, + sample_weights)) + dataset = dataset.repeat() + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + 
combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) - def test_dataset_wrong_input_shape(self, distribution): + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. + def DISABLED_test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # Wrong input shape inputs = np.zeros((10, 5), dtype=np.float32) @@ -927,15 +1032,17 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate(combinations.combine( - distribution=[combinations.mirrored_strategy_with_two_gpus], + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) - def test_dataset_no_batch_input_validation(self, distribution): + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. 
+ def DISABLED_test_dataset_no_batch_input_validation(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # User forgets to batch the dataset inputs = np.zeros((10, 3), dtype=np.float32) @@ -951,11 +1058,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, mode=['graph'])) def test_dataset_input_shape_fully_defined(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) dataset = get_dataset(distribution) # Input shapes are not fully known. Batch dimension is unknown as we are @@ -967,7 +1074,9 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph', 'eager'])) def test_learning_phase_value(self, distribution): @@ -975,16 +1084,17 @@ class TestDistributionStrategyWithDatasets(test.TestCase, # meaningful values. Currently we don't pass the learning phase if the # Lambda layer uses the learning phase. 
with self.cached_session(): - x = keras.layers.Input(shape=(1,), name='input') - y = keras.layers.Dense(1, kernel_initializer='ones')(x) - z = keras.layers.Dropout(0.9999)(y) - model = keras.Model(x, z) - initial_weights = model.get_weights() + with distribution.scope(): + x = keras.layers.Input(shape=(1,), name='input') + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + initial_weights = model.get_weights() - optimizer = gradient_descent.GradientDescentOptimizer(0.005) - loss = 'mse' - metrics = ['acc'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + model.compile(optimizer, loss, metrics=metrics) batch_size = 8 if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): @@ -998,7 +1108,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase, hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) self.assertAlmostEqual(hist.history['acc'][0], 0, 0) - model.set_weights(initial_weights) + with distribution.scope(): + model.set_weights(initial_weights) # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185. 
# evaluate_output = model.evaluate(dataset, steps=20) # self.assertAlmostEqual(evaluate_output[1], 1, 0) @@ -1012,14 +1123,14 @@ class TestDistributionStrategyWithDatasets(test.TestCase, ref_output = np.ones((160, 1), dtype=np.float32) self.assertArrayNear(output, ref_output, 1e-1) - @combinations.generate(strategy_minus_tpu_combinations()) + @combinations.generate(all_strategy_combinations()) def testOptimizerWithCallbacks(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent_keras.SGD(0.01) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent_keras.SGD(0.01) + loss = 'mse' + model.compile(optimizer, loss) dataset = get_dataset(distribution) @@ -1028,375 +1139,187 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - grouped_models = distribution.unwrap(model._grouped_model) - with distribution.scope(): - for m in grouped_models: - self.assertAllClose(0.001, keras.backend.get_value( - m.optimizer.lr), atol=1e-05, rtol=1e-05) - - -class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr)) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_validating_dataset_input_tensors_with_shape_mismatch(self, - distribution): + @combinations.generate(tpu_strategy_combinations()) + def test_predict_with_dataset_with_partial_batch(self, distribution): with self.cached_session(): - a = constant_op.constant([1, 2], shape=(1, 2)) - b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': 
b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with distribution.scope(): - # Removed device and input tensor shape details from the error message - # since the order of the device and the corresponding input tensor shape - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor shapes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - distributed_training_utils.validate_distributed_dataset_inputs( - distribution, x, y) + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_validating_dataset_input_tensors_with_dtype_mismatch(self, - distribution): - with self.cached_session(): - a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) - b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) with distribution.scope(): - # Removed device and input tensor dtype details from the error message - # since the order of the device and the corresponding input tensor dtype - # is not deterministic over different runs. 
- with self.assertRaisesRegexp(ValueError, - 'Input tensor dtypes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - distributed_training_utils.validate_distributed_dataset_inputs( - distribution, x, y) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_unsupported_features(self, distribution): - with self.cached_session(): - model = get_model() + model_with_ds_strategy = get_model() + model_with_ds_strategy.compile(optimizer, loss) - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + cpu_model = get_model() + cpu_model.compile(optimizer, loss) - dataset = get_dataset(distribution) + inputs = np.zeros((10, 3), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs)) - # Test with validation split - with self.assertRaisesRegexp( - ValueError, '`validation_split` argument is not ' - 'supported when input `x` is a dataset or a ' - 'dataset iterator.+'): - model.fit(dataset, - epochs=1, steps_per_epoch=2, verbose=0, - validation_split=0.5, validation_steps=2) - - # Test with sample weight. - sample_weight = np.random.random((10,)) - with self.assertRaisesRegexp( - ValueError, '`sample_weight` argument is not supported when input ' - '`x` is a dataset or a dataset iterator.'): - model.fit( - dataset, - epochs=1, - steps_per_epoch=2, - verbose=0, - sample_weight=sample_weight) - - # Test with not specifying the `steps` argument. 
- with self.assertRaisesRegexp( - ValueError, 'you should specify the `steps_per_epoch` argument'): - model.fit(dataset, epochs=1, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'you should specify the `steps` argument'): - model.evaluate(dataset, verbose=0) + # As sample size is 10, we batch by 4 so that the last batch is + # a partial batch. + dataset_with_partial_batch = dataset.batch(4) + cpu_model.set_weights(model_with_ds_strategy.get_weights()) - with self.assertRaisesRegexp(ValueError, - 'you should specify the `steps` argument'): - model.predict(dataset, verbose=0) + self.assertAllClose( + model_with_ds_strategy.predict(dataset_with_partial_batch, steps=3), + cpu_model.predict(dataset_with_partial_batch, steps=3), + atol=1e-5, rtol=1e-5) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_calling_with_unsupported_predefined_callbacks(self, distribution): + @combinations.generate(tpu_strategy_combinations()) + def test_predict_multi_output_model_with_dataset_with_partial_batch( + self, distribution): with self.cached_session(): - model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - - dataset = get_dataset(distribution) - - def schedule(_): - return 0.001 - with self.assertRaisesRegexp(ValueError, - 'You must specify a Keras Optimizer V2 when ' - 'using'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - with self.assertRaisesRegexp(ValueError, - 'You must specify a Keras Optimizer V2 when ' - 'using'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.ReduceLROnPlateau()]) + with distribution.scope(): + model_with_ds_strategy = 
simple_multi_inputs_multi_outputs_model() + model_with_ds_strategy.compile(optimizer, loss) + cpu_model = simple_multi_inputs_multi_outputs_model() + cpu_model.compile(optimizer, loss) -class TestDistributionStrategyWithLossMasking(test.TestCase, - parameterized.TestCase): + input_data, _ = get_multi_inputs_multi_outputs_data() + input_dict = { + 'input_a': input_data['input_a'], + 'input_b': input_data['input_b'], + } - # TODO(priyag): Enable all strategies for this test. Currently it does not - # work for TPU due to some invalid datatype. - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph', 'eager'])) - def test_masking(self, distribution): - with self.cached_session(): - np.random.seed(1337) - x = np.array([[[1], [1]], [[0], [0]]]) - model = keras.models.Sequential() - model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(1, kernel_initializer='one'))) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=distribution) - y = np.array([[[1], [1]], [[1], [1]]]) - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) - hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) - self.assertEqual(hist.history['loss'][0], 0) + dataset = dataset_ops.Dataset.from_tensor_slices(input_dict) + # As sample size is 200, we batch by 18 using 12 steps per epoch so + # that the last batch is a partial batch. 
+ dataset_with_partial_batch = dataset.batch(18) + cpu_model.set_weights(model_with_ds_strategy.get_weights()) -class TestDistributionStrategyWithNormalizationLayer( - test.TestCase, parameterized.TestCase): + self.assertAllClose( + model_with_ds_strategy.predict(dataset_with_partial_batch, steps=12), + cpu_model.predict(dataset_with_partial_batch, steps=12), + atol=1e-4, rtol=1e-4) - @combinations.generate(all_strategy_combinations()) - def test_batchnorm_correctness(self, distribution): - with self.cached_session(): - model = keras.models.Sequential() - norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) - model.add(norm) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=distribution) - - # centered on 5.0, variance 10.0 - x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) - x = x.astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) - dataset = dataset.repeat(100) - dataset = batch_wrapper(dataset, 32, distribution) - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) - predict_dataset = predict_dataset.repeat(100) - predict_dataset = batch_wrapper(predict_dataset, 32, distribution) +class TestRegularizerLoss(test.TestCase, parameterized.TestCase): + class IdentityRegularizer(keras.regularizers.Regularizer): - model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) - out = model.predict(predict_dataset, steps=2) - out -= keras.backend.eval(norm.beta) - out /= keras.backend.eval(norm.gamma) - np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) - np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + def __call__(self, x): + return array_ops.identity(x) + class AddLayer(keras.layers.Layer): -class TestDistributionStrategyCorrectness(test.TestCase, - parameterized.TestCase): + def build(self, _): + self.v = self.add_weight( + 'v', (), initializer='ones', + regularizer=TestRegularizerLoss.IdentityRegularizer()) - 
@combinations.generate(all_strategy_combinations()) - def test_metric_correctness(self, distribution): - with self.cached_session(): - keras.backend.set_image_data_format('channels_last') - num_samples = 10000 - - x_train = np.random.randint(0, 2, num_samples) - x_train = np.reshape(x_train, (num_samples, 1)) - y_train = x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - - # Create identity model. - model = keras.Sequential() - model.add( - keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones')) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), - metrics=[keras.metrics.BinaryAccuracy()], - distribute=distribution) + def call(self, inputs): + return inputs + self.v - batch_size = 64 - if not distributed_training_utils.global_batch_size_supported( - distribution): - batch_size //= distribution.num_replicas_in_sync - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = batch_wrapper(train_dataset, batch_size, distribution) + @staticmethod + def loss_fn(_, y_pred): + return math_ops.reduce_mean(y_pred) - history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) + @combinations.generate(all_strategy_combinations_minus_default()) + def test_regularizer_loss(self, distribution): + batch_size = 2 + if not distributed_training_utils.global_batch_size_supported(distribution): + batch_size //= distribution.num_replicas_in_sync + + # Given an input x, which is always 1, and variable v, this model computes + # Loss=x+v+regularizer_loss, where regularizer_loss=v and the variable is + # initialized to 1. Therefore, this model computes Loss=1+2v, and so the + # gradient dLoss/dv = 2. This gradient of 2 is averaged over all examples + # in a batch and then multiplied by the learning rate of 1. 
As a result, + # the model update for one batch should subtract 2 from v, resulting in v + # being -1. If the regularizer loss is not scaled correctly by number of + # replicas, the variable value will be incorrect when number of replicas + # >1. For e.g. it will be -2 if num replicas = 2. + with distribution.scope(): + x = keras.layers.Input(shape=(), batch_size=batch_size) + y = TestRegularizerLoss.AddLayer()(x) + model = keras.models.Model(inputs=x, outputs=y) + opt = gradient_descent_keras.SGD(1.) + model.compile(opt, loss=TestRegularizerLoss.loss_fn) + model.fit( + x=np.array([[1.], [1.]], dtype=np.float32), + y=np.array([[1.], [1.]], dtype=np.float32), + batch_size=batch_size) + v = model.get_weights()[0] + self.assertEqual(-1.0, v) + + +class TestDistributionStrategyWithKerasModels(test.TestCase, + parameterized.TestCase): @combinations.generate(all_strategy_combinations()) - def test_eval_metrics_correctness(self, distribution): - with self.cached_session(): - model = keras.Sequential() - model.add( - keras.layers.Dense( - 3, activation='relu', input_dim=4, kernel_initializer='ones')) - model.add( - keras.layers.Dense( - 1, activation='sigmoid', kernel_initializer='ones')) - model.compile( - loss='mae', - metrics=['accuracy', keras.metrics.BinaryAccuracy()], - optimizer=gradient_descent.GradientDescentOptimizer(0.001), - distribute=distribution) - - # verify correctness of stateful and stateless metrics. - x = np.ones((100, 4)).astype('float32') - y = np.ones((100, 1)).astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() - dataset = batch_wrapper(dataset, 4, distribution) - outs = model.evaluate(dataset, steps=10) - self.assertEqual(outs[1], 1.) - self.assertEqual(outs[2], 1.) - - y = np.zeros((100, 1)).astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() - dataset = batch_wrapper(dataset, 4, distribution) - outs = model.evaluate(dataset, steps=10) - self.assertEqual(outs[1], 0.) 
- self.assertEqual(outs[2], 0.) - - @combinations.generate(strategy_and_input_combinations()) - def test_correctness(self, distribution, use_numpy, use_validation_data): - - with self.cached_session(): - default_tolerance = 1e-5 - tol_table = {} - - if isinstance(distribution, (mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy)): - # TODO(b/119257215): Weights are not exactly the same, so use larger - # tolerance for now. Predict should be related to weights. - tol_table = { - 'weights_1': 1e-4, - 'weights_2': 1e-4, - 'predict_result_1': 1e-4, - } - - keras.backend.set_image_data_format('channels_last') - np.random.seed(_RANDOM_SEED) - random_seed.set_random_seed(_RANDOM_SEED) - - # Train, eval, and predict datasets are created with the same input numpy - # arrays. - # TODO(xiejw): Change this back to 10000, once we support final partial - # batch. - num_samples = 9984 - x_train = np.random.rand(num_samples, 1) - y_train = 3 * x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - x_predict = [[1.], [2.], [3.], [4.]] - - # The model is built once and the initial weights are saved. - # This is used to initialize the model for both the distribution and - # non-distribution run. In addition, we add few non-linear layers to make - # it non-trivial. - def _create_model(): - model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) - return model - - model = _create_model() - initial_weights = model.get_weights() - del model # avoid accident usage. - - def fit_eval_and_predict(with_distribution=None): - model = _create_model() - # We have initialized the model to the same weight for the distribution - # and non-distribution run. 
- model.set_weights(initial_weights) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent_keras.SGD(0.5), - metrics=['mse'], - distribute=with_distribution) + def test_distribution_strategy_on_sequential_model(self, distribution): + with distribution.scope(): + model = simple_sequential_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - training_inputs, eval_inputs, predict_inputs = ( - get_correctness_test_inputs(use_numpy, use_validation_data, - with_distribution, - x_train, y_train, x_predict)) + inputs = np.zeros((20, 10), np.float32) + targets = np.zeros((20, 2), np.float32) - result = {} - result['training_history_1'] = model.fit(**training_inputs).history + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) - if eval_inputs is not None: - result['eval_result_1'] = model.evaluate(**eval_inputs) + @combinations.generate(all_strategy_combinations()) + def test_distribution_strategy_on_functional_model(self, distribution): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - result['weights_1'] = model.get_weights() - result['predict_result_1'] = model.predict(**predict_inputs) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) - # Train and eval again to mimic user's flow. + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) - result['training_history_2'] = model.fit(**training_inputs).history + # TODO(b/124377929): Remove error assertions once subclassed models + # are supported in DistributedStrategy. 
+ @combinations.generate(all_strategy_combinations_minus_default()) + def test_distribution_strategy_on_subclassed_model(self, distribution): + with distribution.scope(): + model = simple_subclassed_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - if eval_inputs is not None: - result['eval_result_2'] = model.evaluate(**eval_inputs) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 2), dtype=np.float32) - result['weights_2'] = model.get_weights() + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) - return result + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.predict(inputs, steps=1) - results_with_ds = fit_eval_and_predict(with_distribution=distribution) - results_without_ds = fit_eval_and_predict(with_distribution=None) + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.evaluate(inputs, targets, steps=1) - # Verify that the weights, training history, eval results, predict outputs - # are the same within some limits of tolerance. - for key in results_with_ds: - if (key.startswith('training_history') and - isinstance(distribution, tpu_strategy.TPUStrategy) and - distribution.extended.steps_per_run > 1): - # TODO(b/119894254): Enable this test for all cases once the - # underlying bug is fixed. 
- continue + @combinations.generate(all_strategy_combinations_minus_default()) + def test_distribution_strategy_one_dimensional(self, distribution): + with distribution.scope(): + inp = keras.layers.Input(shape=(10,)) + out = keras.layers.Dense(3, activation='softmax')(inp) + model = keras.Model(inputs=[inp], outputs=[out]) + model.compile( + optimizer='rmsprop', + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy'], + ) - tolerance = tol_table.get(key, default_tolerance) + x = np.random.random((64, 10)).astype('float32') + y = np.random.randint(3, size=64) - self.assertAllClose( - results_with_ds[key], - results_without_ds[key], - atol=tolerance, - rtol=tolerance, - msg='Fail to assert {}.'.format(key)) + model.fit(x, y, epochs=1, steps_per_epoch=2) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/keras_utils_test.py b/tensorflow/contrib/distribute/python/keras_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..36eaee77f21a9f6d62a7c3f616d0126b7a4a8902 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_utils_test.py @@ -0,0 +1,471 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tf.keras models with callbacks, checkpointing with dist strategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import tempfile +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_test as keras_test_lib +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import values +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.training import gradient_descent + + +class Counter(keras.callbacks.Callback): + """Counts the number of times each callback method was run. + + Attributes: + method_counts: dict. Contains the counts of time each callback method was + run. 
+ """ + + def __init__(self): + self.method_counts = collections.defaultdict(int) + methods_to_count = [ + 'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end', + 'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin', + 'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end', + 'on_test_begin', 'on_test_end', 'on_train_batch_begin', + 'on_train_batch_end', 'on_train_begin', 'on_train_end' + ] + for method_name in methods_to_count: + setattr(self, method_name, + self.wrap_with_counts(method_name, getattr(self, method_name))) + + def wrap_with_counts(self, method_name, method): + + def _call_and_count(*args, **kwargs): + self.method_counts[method_name] += 1 + return method(*args, **kwargs) + + return _call_and_count + + +class TestDistributionStrategyWithCallbacks(test.TestCase, + parameterized.TestCase): + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_fit(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + epochs = 2 + steps_per_epoch = 5 + validation_steps = 3 + + model.fit( + dataset, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + verbose=0, + validation_data=dataset, + validation_steps=validation_steps, + callbacks=[counter]) + + if isinstance(distribution, tpu_strategy.TPUStrategy): + # TPU Strategy can have multi step training, from extended.steps_per_run + # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch + steps_per_run = distribution.extended.steps_per_run + num_batch_call_per_epoch = steps_per_epoch // steps_per_run + if steps_per_epoch % steps_per_run: + num_batch_call_per_epoch += 1 + else: + num_batch_call_per_epoch = steps_per_epoch + + self.assertDictEqual( + counter.method_counts, { + 'on_batch_begin': epochs * num_batch_call_per_epoch, + 'on_batch_end': epochs * 
num_batch_call_per_epoch, + 'on_epoch_begin': epochs, + 'on_epoch_end': epochs, + 'on_test_batch_begin': epochs * validation_steps, + 'on_test_batch_end': epochs * validation_steps, + 'on_test_begin': epochs, + 'on_test_end': epochs, + 'on_train_batch_begin': epochs * num_batch_call_per_epoch, + 'on_train_batch_end': epochs * num_batch_call_per_epoch, + 'on_train_begin': 1, + 'on_train_end': 1 + }) + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_eval(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + model.evaluate(dataset, steps=5, callbacks=[counter]) + + self.assertDictEqual( + counter.method_counts, { + 'on_test_batch_begin': 5, + 'on_test_batch_end': 5, + 'on_test_begin': 1, + 'on_test_end': 1 + }) + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_predict(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + model.predict( + keras_test_lib.get_predict_dataset(dataset), + steps=5, + callbacks=[counter]) + + self.assertDictEqual( + counter.method_counts, { + 'on_predict_batch_begin': 5, + 'on_predict_batch_end': 5, + 'on_predict_begin': 1, + 'on_predict_end': 1 + }) + + +class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_shape_mismatch( + self, distribution): + with self.cached_session(): + a = constant_op.constant([1, 2], shape=(1, 2)) + b = 
constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp( + ValueError, 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): + distributed_training_utils.validate_distributed_dataset_inputs( + distribution, x, y) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_dtype_mismatch( + self, distribution): + with self.cached_session(): + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. 
+ with self.assertRaisesRegexp( + ValueError, 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): + distributed_training_utils.validate_distributed_dataset_inputs( + distribution, x, y) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_unsupported_features(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = keras_test_lib.get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + dataset = keras_test_lib.get_dataset(distribution) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + validation_split=0.5, + validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + ValueError, '`sample_weight` argument is not supported when input ' + '`x` is a dataset or a dataset iterator.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument for dataset with infinite + # cardinality. 
+ dataset = dataset.repeat() + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): + model.predict(dataset, verbose=0) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = keras_test_lib.get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + dataset = keras_test_lib.get_dataset(distribution) + + def schedule(_): + return 0.001 + + with self.assertRaisesRegexp( + ValueError, 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp( + ValueError, 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + + +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): + + # TODO(priyag): Enable all strategies for this test. Currently it does not + # work for TPU due to some invalid datatype. 
+ @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_masking(self, distribution): + with self.cached_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + with distribution.scope(): + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + model.compile( + loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class TestDistributionStrategyWithNormalizationLayer(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + combinations.times(keras_test_lib.all_strategy_combinations(), + combinations.combine(fused=[True, False]))) + def test_batchnorm_correctness(self, distribution, fused): + with self.cached_session(): + with distribution.scope(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization( + input_shape=(10,), momentum=0.8, fused=fused) + model.add(norm) + model.compile( + loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + x = x.astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = dataset.repeat(100) + dataset = keras_test_lib.batch_wrapper(dataset, 32, distribution) + + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) + predict_dataset = predict_dataset.repeat(100) + predict_dataset = keras_test_lib.batch_wrapper(predict_dataset, 32, + 
distribution) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(predict_dataset, steps=2) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class TestDistributionStrategySaveLoadWeights(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_save_load_h5(self, distribution): + with self.cached_session(): + dataset = keras_test_lib.get_dataset(distribution) + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(gradient_descent_keras.SGD(0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp('.h5') + model.save_weights(weights_file) + + model_2 = keras_test_lib.get_model() + model_2.compile(gradient_descent_keras.SGD(0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict( + keras_test_lib.get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_save_load_trackable(self, distribution): + # TODO(sourabhbajaj): Test fails with optimizer v2 without h5 + with self.cached_session(): + dataset = keras_test_lib.get_dataset(distribution) + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp() + model.save_weights(weights_file) + + model_2 = keras_test_lib.get_model() + model_2.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict( + keras_test_lib.get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + +class 
TestDistributionStrategyValidation(test.TestCase, parameterized.TestCase): + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_layer_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + with distribution.scope(): + model = keras.Model(x, y) + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_model_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + with distribution.scope(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 8ac659abe96370b751ed1556cc699fe20788a0fd..a663e809dd45ea099e1d8a08e681d07b05bee3c9 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -95,16 +95,15 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): with ops.Graph().as_default(), distribution.scope(): - iterator = distribution.distribute_dataset( - dataset_fn).make_initializable_iterator() + iterator = 
distribution.make_input_fn_iterator(lambda _: dataset_fn()) if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): - value, update = distribution.call_for_each_replica( - metric_fn, args=inputs) + value, update = distribution.extended.call_for_each_replica( + metric_fn, args=(inputs,)) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) - ctx = distribution.run_steps_on_dataset( + ctx = distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=distribution.extended.steps_per_run) update = ctx.run_op value = ctx.non_tensor_outputs["value"] @@ -114,15 +113,14 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): distribution.num_replicas_in_sync * distribution.extended.steps_per_run) else: - value, update = distribution.call_for_each_replica( - metric_fn, iterator.get_next()) + value, update = distribution.extended.call_for_each_replica( + metric_fn, args=(iterator.get_next(),)) update = distribution.group(update) # TODO(josh11b): Once we switch to using a global batch size for input, # replace "distribution.num_replicas_in_sync" with "1". batches_per_update = distribution.num_replicas_in_sync - self.evaluate(iterator.initializer) - self.evaluate(distribution.initialize()) + self.evaluate(iterator.initialize()) self.evaluate(variables.local_variables_initializer()) batches_consumed = 0 @@ -136,8 +134,6 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if batches_consumed >= 4: # Consume 4 input batches in total. 
break - self.evaluate(distribution.finalize()) - @combinations.generate(all_combinations() + tpu_combinations()) def testMean(self, distribution): def _dataset_fn(): diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index f09483cb56b66fd4720ee71085203c14f1ccadc3..f06c9b75644b2890b7657f75e74e4e20a6f15705 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -41,12 +41,9 @@ from tensorflow.python.ops.losses import losses_impl class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): - def _get_iterator(self, ds): - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate(iterator.initializer) + def _get_iterator(self, strategy, input_fn): + iterator = strategy.make_input_fn_iterator(lambda _: input_fn()) + self.evaluate(iterator.initialize()) return iterator @combinations.generate( @@ -67,15 +64,15 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=2).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -84,12 +81,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights, biases = [], [] for _ in range(5): run_step() - weights.append(self.evaluate(layer.kernel)) 
biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) @@ -105,11 +99,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): return distribution.group( - distribution.call_for_each_replica( + distribution.extended.call_for_each_replica( model_fn, args=(iterator.get_next(),))) if not context.executing_eagerly(): @@ -152,7 +146,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # `distribution.scope`. with variable_scope.variable_creator_scope( appending_creator), distribution.scope(): - model_fn, dataset_fn, layer = minimize_loss_example( + model_fn, dataset_fn, _ = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=True, @@ -161,24 +155,21 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) - run_step() - 
self.evaluate(distribution.finalize()) - def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { "GradientDescent": ["dense/kernel", "dense/bias"], @@ -197,7 +188,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.assertEqual( get_expected_variables(optimizer_fn, - len(distribution.parameter_devices)), + len(distribution.extended.parameter_devices)), set(created_variables)) @combinations.generate( @@ -230,18 +221,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused fetches = distribution.unwrap( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) if update_ops_in_cross_replica_mode: fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) return control_flow_ops.group(fetches) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -267,8 +258,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) - self.evaluate(distribution.finalize()) - @combinations.generate( combinations.times( combinations.combine( @@ -302,8 +291,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): with distribution.scope(): all_vars = [] - def model_fn(x, y): - + def model_fn(inputs): + x, y = inputs def loss_fn(): # Use fixed initialization to make the steps deterministic. 
w = variable_scope.get_variable("w", initializer=[[2.]]) @@ -327,15 +316,15 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -370,8 +359,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) - self.evaluate(distribution.finalize()) - @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), @@ -412,8 +399,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): return (train_op, loss) def step_fn(output_context, inputs): - (train_op, loss) = distribution.call_for_each_replica( - model_fn, args=(output_context,) + inputs) + (train_op, loss) = distribution.extended.call_for_each_replica( + model_fn, args=(output_context, inputs)) output_context.set_last_step_output( name="cross_replica_loss_reduced", output=loss, @@ -423,7 +410,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): output=loss) return distribution.group(train_op) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): initial_loss = lambda: constant_op.constant(1e7) @@ -439,7 +426,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): 
"cross_replica_loss_not_reduced": distribution.unwrap(distribution.broadcast(initial_loss())) } - ctx = distribution.run_steps_on_dataset( + ctx = distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=2, initial_loop_values=initial_loop_values) @@ -458,7 +445,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): reduced=False, distribution=distribution) return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"]) - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -471,8 +457,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:])) self.assertTrue(loss_is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 20f1a08d4261b931a9353738147fba7d7dff9225..5391e083fc9b3ed99cc64bbed11bdeb8dea07f93 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,17 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools - -from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import mirrored_strategy -from tensorflow.python.distribute import values # pylint: disable=protected-access,invalid-name _call_for_each_replica = mirrored_strategy._call_for_each_replica -_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value _create_mirrored_variable = 
mirrored_strategy._create_mirrored_variable all_local_devices = mirrored_strategy.all_local_devices CoreMirroredStrategy = mirrored_strategy.MirroredStrategy @@ -50,8 +46,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): distributed environment. There are several important concepts for distributed TensorFlow, e.g. - `client`, `job`, 'task', `cluster`, `in-graph replication` and - 'synchronous training' and they have already been defined in the + `client`, `job`, `task`, `cluster`, `in-graph replication` and + `synchronous training` and they have already been defined in the [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). The distribution strategy inherits these concepts as well and in addition to that we also clarify several more concepts: @@ -106,6 +102,61 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): auto_shard_dataset) super(MirroredStrategy, self).__init__(extended) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def make_dataset_iterator(self, dataset): # pylint: disable=useless-super-delegation + """Makes an iterator for input provided via `dataset`. + + NOTE: The batch size of the `dataset` argument is treated differently for + this contrib version of `MirroredStrategy`. + + Data from the given dataset will be distributed evenly across all the + compute replicas. We will assume that the input dataset is batched by the + per-replica batch size. + + The user could also use `make_input_fn_iterator` if they want to + customize which input is fed to which replica/worker etc. + + Args: + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. 
+ """ + return super(MirroredStrategy, self).make_dataset_iterator(dataset) + + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. + + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `MirroredStrategy`. + + Args: + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. + """ + return super(MirroredStrategy, self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) + class MirroredExtended(CoreMirroredExtended): """Implementation of (contrib) MirroredStrategy.""" @@ -137,24 +188,10 @@ class MirroredExtended(CoreMirroredExtended): Returns: An `InputIterator` which returns inputs for each step of the computation. 
""" - if self._local_mode: - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, self._devices)] - else: - worker_device_pairs = self._worker_devices - return values.DatasetIterator(dataset, worker_device_pairs) - - def _distribute_dataset(self, dataset_fn): - if self._local_mode: - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._devices) - else: - return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), - self._worker_devices, - auto_shard=self._auto_shard_dataset) + return input_lib.DatasetIterator(dataset, self._input_workers) # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """The contrib version of Mirrored strategy uses per-replica batch size.""" return False diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 36be5c83f8bafb6c934d1d7682b5227b1f71c089..5ce731816ccefe36c1f876c79589e448f00b86f5 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -38,8 +38,8 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.framework import func_graph from tensorflow.python.framework import dtypes +from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core as keras_core @@ -66,8 +66,10 @@ GPU_TEST = "test_gpu" in sys.argv[0] combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=["graph", "eager"])) -class 
MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class MirroredTwoDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): def testMinimizeLoss(self, distribution): if context.executing_eagerly(): @@ -101,7 +103,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, expected = sum(range(distribution.num_replicas_in_sync)) self.assertEqual(expected, self.evaluate(reduced)) - def testMakeInputFnIterator(self, distribution): + def testMakeInputFnIteratorWithDataset(self, distribution): dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i, i+1] for i in range(0, 10, 2)] @@ -114,9 +116,48 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, expected_values) + # TODO(b/124344198): Re-enable after fixing this flaky test. 
+ def DISABLED_testMakeInputFnIteratorWithCallable(self, distribution): + def fn(): + dataset = dataset_ops.Dataset.range(2).interleave( + (lambda _: dataset_ops.Dataset.range(10)), cycle_length=2) + it = dataset.make_one_shot_iterator() + return it.get_next + expected_values = [[i, i] for i in range(0, 10)] + + input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=2, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, + expected_values, test_reinitialize=False) + + def testNumpyIterator(self, distribution): + self._test_numpy_iterator(distribution) + def testGlobalStepUpdate(self, distribution): self._test_global_step_update(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + def one_device_combinations(): return combinations.combine( @@ -128,25 +169,42 @@ def one_device_combinations(): mode=["graph", "eager"]) +@combinations.generate(one_device_combinations()) class MirroredOneDeviceDistributionTest( strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase, parameterized.TestCase): - @combinations.generate(one_device_combinations()) def testMinimizeLoss(self, distribution): if context.executing_eagerly(): self._test_minimize_loss_eager(distribution) else: 
self._test_minimize_loss_graph(distribution) - @combinations.generate(one_device_combinations()) def testReplicaId(self, distribution): self._test_replica_id(distribution) - @combinations.generate(one_device_combinations()) def testCallAndMergeExceptions(self, distribution): self._test_call_and_merge_exceptions(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + class MirroredStrategyVariableCreatorStackTest( test.TestCase, parameterized.TestCase): @@ -183,6 +241,34 @@ class MirroredStrategyVariableCreatorStackTest( expected = ("main_thread:thread_0", "main_thread:thread_1") self.assertEqual(expected, result) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredStrategyCallForEachReplicaTest(test.TestCase): + + def testExecutingEagerlyOutsideFunction(self, distribution): + """Verify we preserve the value of executing_eagerly_outside_functions().""" + def model_fn(): + return ops.executing_eagerly_outside_functions() + + originally = ops.executing_eagerly_outside_functions() + with distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + 
self.assertEqual(in_scope, originally) + + # Verify this all again, but this time in a FuncGraph. + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + @combinations.generate(combinations.combine( distribution=[ @@ -193,11 +279,13 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # TODO(priyag): Modify more tests to use this helper and check more # properties. - def _test_mv_properties(self, var, name): + def _test_mv_properties(self, var, name, strategy): self.assertIsInstance(var, values.MirroredVariable) self.assertEqual(name, var.name) + self.assertIs(strategy, var.distribute_strategy) for d in var.devices: self.assertEqual(d, var.get(d).device) + self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access def testVariableInFuncGraph(self, distribution): def model_fn(): @@ -209,8 +297,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v1 = variable_scope.variable(1.0, name="foo") v2 = distribution.extended.call_for_each_replica(model_fn) - self._test_mv_properties(v1, "foo:0") - self._test_mv_properties(v2, "bar:0") + self._test_mv_properties(v1, "foo:0", distribution) + self._test_mv_properties(v2, "bar:0", distribution) def testSingleVariable(self, distribution): def model_fn(): @@ -223,7 +311,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) - self._test_mv_properties(result, "foo:0") + self._test_mv_properties(result, "foo:0", distribution) def testUnnamedVariable(self, distribution): def model_fn(): @@ -233,7 +321,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = 
distribution.extended.call_for_each_replica(model_fn) - self._test_mv_properties(result, "Variable:0") + self._test_mv_properties(result, "Variable:0", distribution) def testMultipleVariables(self, distribution): def model_fn(): @@ -246,7 +334,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) for i, v in enumerate(result): - self._test_mv_properties(v, "foo" + str(i) + ":0") + self._test_mv_properties(v, "foo" + str(i) + ":0", distribution) def testMultipleVariablesWithSameCanonicalName(self, distribution): def model_fn(): @@ -296,14 +384,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase): (layer2.kernel, layer2.bias), (layer3.kernel, layer3.bias)] - ds = distribution.distribute_dataset( - lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate([iterator.initializer]) - + iterator = distribution.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) + self.evaluate(iterator.initialize()) features = iterator.get_next() with distribution.scope(): @@ -524,16 +607,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): aggregation="invalid") def testNonMatchingVariableCreation(self, distribution): + self.skipTest("b/123075960") def model_fn(name): v = variable_scope.variable(1.0, name=name) ds_context.get_replica_context().merge_call(lambda _: _) return v with distribution.scope(): - names = values.DistributedValues({ - "/device:CPU:0": "foo", - "/device:GPU:0": "bar" - }) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + names = values.DistributedValues(device_map, ("foo", "bar")) with self.assertRaises(RuntimeError): _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) @@ -667,6 +749,15 @@ class 
MirroredStrategyVariableCreationTest(test.TestCase): distribution.extended.worker_devices[0]).read_value())) self.assertEqual(10.0, self.evaluate(ret_v_sum)) + def testVarDistributeStrategy(self, distribution): + with distribution.scope(): + mirrored = variable_scope.variable(1.0) + replica_local = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ) + self.assertIs(distribution, mirrored.distribute_strategy) + self.assertIs(distribution, replica_local.distribute_strategy) + @combinations.generate(combinations.combine( distribution=[ @@ -1095,7 +1186,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # When we read the value using `read_var` we should see the SUM of each of # values on each of the replicas. self.assertEqual(2.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) # Assigning 6.0 in cross replica context will assign a value of # 6.0/num_replicas to each replica. tlv_ops = replica_local_var.assign(6.0) @@ -1104,7 +1195,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # The value on all the replicas are added before being returned by # `read_var`. self.assertEqual(6.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) def testAssignReplicaLocalVarMeanAggregation(self, distribution): def model_fn(): @@ -1123,13 +1214,13 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # When we read the value using `read_var` we should see the MEAN of values # on all replicas which is the value assigned in replica context. self.assertEqual(1.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) tlv_ops = replica_local_var.assign(6.0) self.evaluate(tlv_ops) # On reading the replica local var we should get the MEAN of all values # which is equal to the value assigned. 
self.assertEqual(6.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) class MockModel(object): @@ -1182,14 +1273,14 @@ class MirroredStrategyDefunTest(test.TestCase): result = distribution.extended.call_for_each_replica( model_fn, args=[mock_model] + inputs) - for device in devices: - device_result = values.select_device(device, result) - device_expected_result = values.select_device(device, expected_result) + for r in range(len(devices)): + device_result = values.select_replica(r, result) + device_expected_result = values.select_replica(r, expected_result) self.assertAllClose(device_expected_result, self.evaluate(device_result)) for defun in defuns: - # PolymorphicFunctions are specialized to the current device stack, so + # `Function`s are specialized to the current device stack, so # call_for_each has one trace per device. To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. 
@@ -1265,9 +1356,9 @@ class MirroredStrategyDefunTest(test.TestCase): def fn1(mock_model, factor): return mock_model(factor) - factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) - expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, - "GPU:0": 3.0 * 1.25}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + factors = values.PerReplica(device_map, (5.0, 3.0)) + expected_result = values.PerReplica(device_map, (5.0 * 1.25, 3.0 * 1.25)) self._call_and_check(distribution, fn1, [factors], expected_result, [fn1]) def testTrain(self, distribution): @@ -1344,7 +1435,7 @@ class MultiWorkerMirroredStrategyTest( self.assertEqual(a.device, "/job:worker/task:0") self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") - def testMakeInputFnIterator(self, distribution): + def testMakeInputFnIteratorWithDataset(self, distribution): self._configure_distribution_strategy(distribution) dataset_fn = lambda: dataset_ops.Dataset.range(100) num_gpus = context.num_gpus() @@ -1365,6 +1456,32 @@ class MultiWorkerMirroredStrategyTest( self._test_input_fn_iterator( iterator, distribution.extended.worker_devices, expected_values, sess) + def DISABLED_testMakeInputFnIteratorWithCallable(self, distribution): + self._configure_distribution_strategy(distribution) + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next + num_gpus = context.num_gpus() + num_workers = 2 + + expected_values = [] + for i in range(0, 100, num_gpus): + expected_values.append([i+j for j in range(num_gpus)] * num_workers) + + with context.graph_mode(), self.cached_session() as sess: + # `expected_input_pipeline_id` is None because the input_fn will be called + # multiple times, each with a different input_pipeline_id. 
+ input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=num_workers*num_gpus, + expected_num_input_pipelines=num_workers, + expected_input_pipeline_id=None) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values, sess, + test_reinitialize=False) + def testUpdateConfigProto(self, distribution): distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]}) diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py index 17b7ab74f63f42e1ee14a82d3bffdd1df9b25857..53e35ea6b75088a3de9866973f872da4a4ce25d6 100644 --- a/tensorflow/contrib/distribute/python/monitor.py +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -51,7 +51,7 @@ class Monitor(object): else: if session is None: raise ValueError("Should provide a `session` in Graph mode.") - session.run(step_callable._iterator.initializer) # pylint: disable=protected-access + session.run(step_callable.initialize()) self._run_step = session.make_callable(step_callable()) session.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 16be839e1d155003b9490fbe3da6ab85b7d2d78a..c0651610cafc06a6d5f4206f4e64d27020fae30b 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -23,9 +23,9 @@ import numpy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import monitor as monitor_lib -from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example from tensorflow.python.client import session +from tensorflow.python.distribute import one_device_strategy from tensorflow.python.eager import context from 
tensorflow.python.eager import test from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index 8f13e9153ea7a951dd722c4549882c97e79b57fe..c4622cdd2af2f6a9c936fe554bcc2eb76f805fdc 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -53,7 +53,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): return var, assign with distribution.scope(), self.cached_session() as sess: - var, assign = distribution.call_for_each_replica(replica_fn) + var, assign = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([10.0, 11.0], var.eval()) sess.run(distribution.unwrap(assign)) @@ -79,7 +79,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): return var, assign.op with distribution.scope(), self.cached_session() as sess: - var, assign_op = distribution.call_for_each_replica(replica_fn) + var, assign_op = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([0.0, 0.0], var.eval()) sess.run(distribution.unwrap(assign_op)) @@ -152,7 +152,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): return var, assign with distribution.scope(), self.cached_session() as sess: - var, assign = distribution.call_for_each_replica(replica_fn) + var, assign = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([10.0, 11.0], var.eval()) sess.run(distribution.unwrap(assign)) diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 147c9b83f866fd364ea23cf7988692a7b5f61b9c..7dca13a5b41d1a2db474c44c82f1da88be84df05 100644 --- 
a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -40,6 +40,7 @@ from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator from tensorflow.python.training import server_lib ASSIGNED_PORTS = set() @@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase): self._mock_os_env = MockOsEnv() self._mock_context = test.mock.patch.object(os, 'environ', self._mock_os_env) + self._coord = coordinator.Coordinator() super(IndependentWorkerTestBase, self).setUp() self._mock_context.__enter__() @@ -368,8 +370,9 @@ class IndependentWorkerTestBase(test.TestCase): super(IndependentWorkerTestBase, self).tearDown() def _task_thread(self, task_fn, tf_config, *args, **kwargs): - os.environ['TF_CONFIG'] = json.dumps(tf_config) - task_fn(*args, **kwargs) + with self._coord.stop_on_exception(): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, *args, **kwargs): @@ -403,3 +406,28 @@ class IndependentWorkerTestBase(test.TestCase): *args, **kwargs) threads[task_type].append(t) return threads + + def join_independent_workers(self, worker_threads): + self._coord.join(worker_threads) + + +def get_tf_config_task(): + return json.loads(os.environ['TF_CONFIG'])['task'] + + +def get_tf_config_cluster_spec(): + return json.loads(os.environ['TF_CONFIG'])['cluster'] + + +def get_task_type(): + return get_tf_config_task()['type'] + + +def get_task_index(): + return get_tf_config_task()['index'] + + +def is_chief(): + return ('chief' not in get_tf_config_cluster_spec() + and get_task_type() == 'worker' + and get_task_index() == 0) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py 
b/tensorflow/contrib/distribute/python/one_device_strategy.py index fdbfba4e04358451a46b23ef250dc7c534c855a0..13a501394ee1fec2dfc1427f6d16d3a4624d7747 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -18,202 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six +from tensorflow.python.distribute import one_device_strategy -from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import values -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.util import nest - - -# TODO(josh11b): Replace asserts in this file with if ...: raise ... - - -class OneDeviceStrategy(distribute_lib.DistributionStrategy): - """A distribution strategy for running on a single device.""" - # TODO(josh11b): Do we wrap values in types to generate errors if you are - # doing something that won't work with other DistributionStrategy - # implementations? 
- - def __init__(self, device): - super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device)) - - -class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): - """Implementation of OneDeviceStrategy.""" - - def __init__(self, container_strategy, device): - super(OneDeviceExtended, self).__init__(container_strategy) - self._device = device - self._default_device = device - - def _create_variable(self, next_creator, *args, **kwargs): - colocate_with = kwargs.pop("colocate_with", None) - if colocate_with is None: - with ops.device(self._device): - return next_creator(*args, **kwargs) - if isinstance(colocate_with, six.string_types): - with ops.device(colocate_with): - return next_creator(*args, **kwargs) - if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and - isinstance(colocate_with[0], six.string_types)): - with ops.device(colocate_with[0]): - return next_creator(*args, **kwargs) - with ops.colocate_with(colocate_with): - return next_creator(*args, **kwargs) - - def _make_dataset_iterator(self, dataset): - """Make iterator from dataset without splitting the batch.""" - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] - return values.DatasetIterator(dataset, worker_device_pairs) - - def _distribute_dataset(self, dataset_fn): - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), [self._device]) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] - return values.InputFunctionIterator( - input_fn, worker_device_pairs, - [distribute_lib.InputContext()]) - - def _broadcast_to(self, tensor, destinations): - del destinations - return tensor - - # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
- def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, - initial_loop_values=None): - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - - ctx = values.MultiStepContext() - def body(i, *args): - """A wrapper around `fn` to create the while loop body.""" - del args - fn_inputs = iterator.get_next() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, fn_inputs) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - with ops.control_dependencies([fn_result]): - return [i + 1] + flat_last_step_outputs - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop. This is useful in cases where we might need to exit - # these contexts and get back to the outer context to do some things, for - # e.g. create an op which should be evaluated only once at the end of the - # loop on the host. One such usage is in creating metrics' value op. - self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - # TODO(priyag): Use max_iterations instead of an explicit counter. - cond = lambda i, *args: i < iterations - i = constant_op.constant(0) - loop_result = control_flow_ops.while_loop( - cond, body, [i] + initial_loop_values, name="", - parallel_iterations=1, back_prop=False, swap_memory=False, - return_same_structure=True) - del self._outer_control_flow_context - - ctx.run_op = control_flow_ops.group(loop_result) - - # Convert the last_step_outputs from a list to the original dict structure - # of last_step_outputs. 
- last_step_tensor_outputs = loop_result[1:] - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - return ctx - - def _call_for_each_replica(self, fn, args, kwargs): - strategy = self._container_strategy() - with ops.device(self._device), _OneDeviceReplicaContext(strategy): - return fn(*args, **kwargs) - - def _reduce_to(self, reduce_op, value, destinations): - del reduce_op, destinations - return value - - def _update(self, var, fn, args, kwargs, group): - # The implementations of _update() and _update_non_slot() are identical - # except _update() passes `var` as the first argument to `fn()`. - return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) - - def _update_non_slot(self, colocate_with, fn, args, kwargs, group): - del colocate_with - with ops.device(self._device), distribute_lib.UpdateContext(self._device): - result = fn(*args, **kwargs) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - def read_var(self, replica_local_var): - """Read the aggregate value of a replica-local variable.""" - return array_ops.identity(replica_local_var) - - def _unwrap(self, value): - return (value,) - - def value_container(self, value): - return value - - @property - def _num_replicas_in_sync(self): - return 1 - - @property - def worker_devices(self): - return (self._device,) - - @property - def parameter_devices(self): - return (self._device,) - - def non_slot_devices(self, var_list): - del var_list - return (self._device,) - - @property - def experimental_should_init(self): - return True - - @property - def should_checkpoint(self): - return True - - @property - def should_save_summary(self): - return True - - # TODO(priyag): Delete this once all strategies use global batch size. 
- @property - def _global_batch_size(self): - return True - - -class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): - """ReplicaContext for OneDeviceStrategy.""" - - def __init__(self, distribution_strategy): - distribute_lib.ReplicaContext.__init__( - self, - distribution_strategy, - replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) - - @property - def devices(self): - return self._distribution_strategy.extended.worker_devices +OneDeviceStrategy = one_device_strategy.OneDeviceStrategy diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index d46cd6f529e363f76bfa2b22339add63530cfde8..0e56f663d6a1ed7945befd933f2f4a83c5f64342 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -18,34 +18,35 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.eager import test -from tensorflow.python.framework import test_util -class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): +@combinations.generate(combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.one_device_strategy_gpu], + mode=["eager", "graph"])) +class OneDeviceStrategyTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase): - def _get_distribution_strategy(self): - return one_device_strategy.OneDeviceStrategy("/device:CPU:0") + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + 
self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) - def testMinimizeLossEager(self): - self._test_minimize_loss_eager(self._get_distribution_strategy()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) - def testReplicaId(self): - self._test_replica_id(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testMakeInputFnIterator(self): - d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + def testMakeInputFnIteratorWithDataset(self, distribution): dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i] for i in range(10)] input_fn = self._input_fn_to_test_input_context( @@ -53,9 +54,46 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): expected_num_replicas_in_sync=1, expected_num_input_pipelines=1, expected_input_pipeline_id=0) - iterator = d.make_input_fn_iterator(input_fn) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values) + + def testMakeInputFnIteratorWithCallable(self, distribution): + def fn(): + dataset = dataset_ops.Dataset.range(10) + it = dataset.make_one_shot_iterator() + return it.get_next + expected_values = [[i] for i in range(10)] + input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=1, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) self._test_input_fn_iterator( - iterator, d.extended.worker_devices, expected_values) 
+ iterator, distribution.extended.worker_devices, expected_values, + test_reinitialize=False) + + def testNumpyIterator(self, distribution): + self._test_numpy_iterator(distribution) + + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index fa4705af7cb592119f56686d1f693a156f7b4b13..e388061b17a9b92dedbbf9839049b13c8575a22c 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -41,21 +41,17 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - - ds = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() + iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) def run_step(): - return control_flow_ops.group(distribution.unwrap( - distribution.call_for_each_replica( - model_fn, args=(iterator.get_next(),)))) + return control_flow_ops.group( + distribution.unwrap( + distribution.extended.call_for_each_replica( + model_fn, args=(iterator.get_next(),)))) if not 
context.executing_eagerly(): with self.cached_session() as sess: - sess.run(iterator.initializer) + sess.run(iterator.initialize()) run_step = sess.make_callable(run_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 2c7766f95fbcb7b68a53ad0052f21485c763a1db..e42bc50fdc4e5e93c998708b0790fdea7768faf2 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -18,34 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.distribute import values -from tensorflow.python.eager import context -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import device_setter -from tensorflow.python.util import nest +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import parameter_server_strategy +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver + +# pylint: disable=protected-access,invalid-name,line-too-long +CoreParameterServerStrategy = 
parameter_server_strategy.ParameterServerStrategy +CoreParameterServerExtended = parameter_server_strategy.ParameterServerStrategyExtended -_LOCAL_CPU = "/device:CPU:0" -_LOCAL_GPU_0 = "/device:GPU:0" +# pylint: enable=protected-access,invalid-name,line-too-long -# TODO(yuefengz): maybe cache variables on local CPU. -# TODO(yuefengz): we may want to set session options to disallow communication -# between workers. class ParameterServerStrategy(distribute_lib.DistributionStrategy): """A parameter server DistributionStrategy. + *** contrib version *** + This strategy class works for both local training and between-graph replicated training for multiple workers. If `cluster_spec` is specified, either passed in to __init__() method or parsed from the @@ -80,9 +70,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): variables. 3) It is also not recommended to open a colocation scope (i.e. calling - `tf.colocate_with`) under the strategy's scope. For colocating variables, - use `distribution.colocate_vars_with` instead. Colocation of ops will possibly - create conflicts of device assignment. + `tf.colocate_with`) under the strategy's scope. For colocating variables, use + `strategy.extended.colocate_vars_with` instead. Colocation of ops will + possibly create conflicts of device assignment. """ def __init__(self, num_gpus_per_worker=0): @@ -99,431 +89,84 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): super(ParameterServerStrategy, self).__init__( ParameterServerExtended(self, num_gpus_per_worker)) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def make_dataset_iterator(self, dataset): # pylint: disable=useless-super-delegation + """Makes an iterator for input provided via `dataset`. 
-class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): - """Implementation of ParameterServerStrategy.""" + NOTE: The batch size of the `dataset` argument is treated differently for + this contrib version of `ParameterServerStrategy`. - def __init__(self, container_strategy, num_gpus_per_worker): - super(ParameterServerExtended, self).__init__(container_strategy) - self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local(num_gpus_per_worker) + Data from the given dataset will be distributed evenly across all the + compute replicas. We will assume that the input dataset is batched by the + per-replica batch size. - # We typically don't need to do all-reduce in this strategy. - self._cross_device_ops = ( - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( - reduce_to_device=_LOCAL_CPU)) - - def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, - task_type, task_id): - """Initialize devices for multiple workers. - - It creates variable devices and compute devices. Variables and operations - will be assigned to them respectively. We have one compute device per - replica. The variable device is a device function or device string. The - default variable device assigns variables to parameter servers in a - round-robin fashion. + The user could also use `make_input_fn_iterator` if they want to + customize which input is fed to which replica/worker etc. Args: - num_gpus_per_worker: number of local GPUs or GPUs per worker. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type. - task_id: the current task id. + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. - Raises: - ValueError: if the cluster_spec doesn't have ps jobs. + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. 
""" - assert cluster_spec - if not task_type or task_id is None: - raise ValueError("When `cluster_spec` is given, you must also specify " - "`task_type` and `task_id`") - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - - self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) - - # Define compute devices which is a list of device strings and one for each - # replica. When there are GPUs, replicate operations on these GPUs. - # Otherwise, place operations on CPU. - if num_gpus_per_worker > 0: - self._compute_devices = tuple( - "%s/device:GPU:%d" % (self._worker_device, i) - for i in range(num_gpus_per_worker) - ) - else: - self._compute_devices = (self._worker_device,) - - self._compute_devices = tuple( - map(device_util.resolve, self._compute_devices)) - self._canonical_compute_device_set = set(self._compute_devices) - - # In distributed mode, place variables on ps jobs in a round-robin fashion. - # Note that devices returned from `replica_device_setter` are not - # canonical and therefore we don't canonicalize all variable devices to - # make them consistent. - # TODO(yuefengz): support passing a strategy object to control variable - # assignment. - # TODO(yuefengz): merge the logic of replica_device_setter into this - # class. - num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) - if num_ps_replicas == 0: - raise ValueError("The cluster spec needs to have `ps` jobs.") - self._variable_device = device_setter.replica_device_setter( - ps_tasks=num_ps_replicas, - worker_device=self._worker_device, - merge_devices=True, - cluster=cluster_spec) - - # The `_parameter_devices` is needed for the `parameter_devices` property - # and is a list of all variable devices. Here parameter devices are all - # tasks of the "ps" job. - self._parameter_devices = tuple(map("/job:ps/task:{}".format, - range(num_ps_replicas))) - - # Add a default device so that ops without specified devices will not end up - # on other workers. 
- self._default_device = self._worker_device - - self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, - task_id) - self._cluster_spec = cluster_spec - self._task_type = task_type - self._task_id = task_id - - logging.info( - "Multi-worker ParameterServerStrategy with " - "cluster_spec = %r, task_type = %r, task_id = %r, " - "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, " - "variable_device = %r", cluster_spec.as_dict(), task_type, task_id, - num_ps_replicas, self._is_chief, self._compute_devices, - self._variable_device) - - def _initialize_local(self, num_gpus_per_worker): - """Initialize internal devices for local training.""" - self._worker_device = device_util.canonicalize("/device:CPU:0") - # Define compute devices which is a list of device strings and one for each - # replica. When there are GPUs, replicate operations on these GPUs. - # Otherwise, place operations on CPU. - if num_gpus_per_worker > 0: - self._compute_devices = tuple( - map("/device:GPU:{}".format, range(num_gpus_per_worker))) - else: - self._compute_devices = (_LOCAL_CPU,) - - self._compute_devices = tuple( - map(device_util.resolve, self._compute_devices)) - self._canonical_compute_device_set = set(self._compute_devices) - - # If there is only one GPU, put everything on that GPU. Otherwise, place - # variables on CPU. 
- if num_gpus_per_worker == 1: - assert len(self._compute_devices) == 1 - self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = (_LOCAL_GPU_0,) - else: - self._variable_device = _LOCAL_CPU - self._parameter_devices = (_LOCAL_CPU,) - - self._is_chief = True - self._cluster_spec = None - self._task_type = None - self._task_id = None - - logging.info( - "ParameterServerStrategy with compute_devices = %r, " - "variable_device = %r", self._compute_devices, self._variable_device) - - def _distribute_dataset(self, dataset_fn): - """Distributes the dataset to each local GPU.""" - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._compute_devices, True) - - def _make_dataset_iterator(self, dataset): - worker_device_pairs = [(self._worker_device, self._compute_devices)] - return values.DatasetIterator(dataset, worker_device_pairs, - self._num_replicas_in_sync) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - """Distributes the dataset to each local GPU.""" - if self._cluster_spec: - input_pipeline_id = multi_worker_util.id_in_cluster( - self._cluster_spec, self._task_type, self._task_id) - num_input_pipelines = multi_worker_util.worker_count( - self._cluster_spec, self._task_type) - else: - input_pipeline_id = 0 - num_input_pipelines = 1 - input_context = distribute_lib.InputContext( - num_input_pipelines=num_input_pipelines, - input_pipeline_id=input_pipeline_id, - num_replicas_in_sync=self._num_replicas_in_sync) - worker_device_pairs = [(self._worker_device, self._compute_devices)] - return values.InputFunctionIterator( - input_fn, worker_device_pairs, [input_context]) - - def _broadcast_to(self, tensor, destinations): - # This is both a fast path for Python constants, and a way to delay - # converting Python values to a tensor until we know what type it - # should be converted to. 
Otherwise we have trouble with: - # global_step.assign_add(1) - # since the `1` gets broadcast as an int32 but global_step is int64. - if isinstance(tensor, (float, int)): - return tensor - if not cross_device_ops_lib.check_destinations(destinations): - destinations = self._compute_devices - return self._cross_device_ops.broadcast(tensor, destinations) - - def _allow_variable_partition(self): - return not context.executing_eagerly() - - # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through - # this creator, such as "MutableHashTable". - def _create_variable(self, next_creator, *args, **kwargs): - if self._num_replicas_in_sync > 1: - aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) - if aggregation not in ( - vs.VariableAggregation.NONE, - vs.VariableAggregation.SUM, - vs.VariableAggregation.MEAN, - vs.VariableAggregation.ONLY_FIRST_REPLICA - ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - def var_creator(*args, **kwargs): - """Create an AggregatingVariable and fix up collections.""" - # Record what collections this variable should be added to. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # Create and wrap the variable. - v = next_creator(*args, **kwargs) - wrapped = values.AggregatingVariable(v, aggregation) - - # Add the wrapped variable to the requested collections. - # The handling of eager mode and the global step matches - # ResourceVariable._init_from_args(). - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the contained - # variable to the TRAINABLE_VARIABLES collection, so we manually - # remove it and replace with the wrapper. We can't set "trainable" - # to False for next_creator() since that causes functions like - # implicit_gradients to skip those variables. 
- if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - l.remove(v) - g.add_to_collections(collections, wrapped) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped) - - return wrapped - else: - var_creator = next_creator - - if "colocate_with" in kwargs: - with ops.device(None): - with ops.colocate_with(kwargs["colocate_with"]): - return var_creator(*args, **kwargs) - - with ops.colocate_with(None, ignore_existing=True): - with ops.device(self._variable_device): - return var_creator(*args, **kwargs) - - def _call_for_each_replica(self, fn, args, kwargs): - # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica( - self._container_strategy(), fn, args, kwargs) + return super(ParameterServerStrategy, self).make_dataset_iterator(dataset) - def _verify_destinations_not_different_worker(self, destinations): - if not self._cluster_spec: - return - if destinations is None: - return - for d in cross_device_ops_lib.get_devices_from(destinations): - d_spec = tf_device.DeviceSpec.from_string(d) - if d_spec.job == self._task_type and d_spec.task != self._task_id: - raise ValueError( - "Cannot reduce to another worker: %r, current worker is %r" % - (d, self._worker_device)) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. 
- def _reduce_to(self, reduce_op, value, destinations): - self._verify_destinations_not_different_worker(destinations) - if not isinstance(value, values.DistributedValues): - # pylint: disable=protected-access - return mirrored_strategy._reduce_non_distributed_value( - self, reduce_op, value, destinations) - return self._cross_device_ops.reduce( - reduce_op, value, destinations=destinations) - - def _batch_reduce_to(self, reduce_op, value_destination_pairs): - for _, destinations in value_destination_pairs: - self._verify_destinations_not_different_worker(destinations) - return self._cross_device_ops.batch_reduce(reduce_op, - value_destination_pairs) - - def _select_single_value(self, structured): - """Select any single values in `structured`.""" - - def _select_fn(x): # pylint: disable=g-missing-docstring - if isinstance(x, values.Mirrored): - if len(x.devices) == 1: - return list(x._index.values())[0] # pylint: disable=protected-access - else: - raise ValueError( - "You cannot update variable with a Mirrored object with multiple " - "components %r when using ParameterServerStrategy. You must " - "specify a single value or a Mirrored with a single value." % x) - elif isinstance(x, values.PerReplica): - raise ValueError( - "You cannot update variable with a PerReplica object %r when using " - "ParameterServerStrategy. You must specify a single value or a " - "Mirrored with a single value" % x) - else: - return x - - return nest.map_structure(_select_fn, structured) - - def _update(self, var, fn, args, kwargs, group): - if isinstance(var, values.AggregatingVariable): - var = var.get() - if not isinstance(var, resource_variable_ops.ResourceVariable): - raise ValueError( - "You can not update `var` %r. It must be a Variable." 
% var) - with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): - result = fn(var, *self._select_single_value(args), - **self._select_single_value(kwargs)) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, fn, args, kwargs, group): - with ops.device( - colocate_with.device), distribute_lib.UpdateContext(colocate_with): - result = fn(*args, **kwargs) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - # Return in a deterministic order. - if set(val.devices) == self._canonical_compute_device_set: - return tuple(val.get(device=d) for d in self._compute_devices) - return tuple(val.get(device=d) for d in sorted(val.devices)) - return (val,) - - def value_container(self, val): - if (hasattr(val, "_aggregating_container") and - not isinstance(val, values.AggregatingVariable)): - wrapper = val._aggregating_container() # pylint: disable=protected-access - if wrapper is not None: - return wrapper - return val - - def read_var(self, var): - # No need to distinguish between normal variables and replica-local - # variables. - return array_ops.identity(var) - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - """Configures the strategy class. - - The strategy object will be re-initialized if `cluster_spec` is given but - was not passed in the constructor. + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `ParameterServerStrategy`. Args: - session_config: not used currently. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type. - task_id: the current task id. 
- - Raises: - ValueError: if `cluster_spec` is given but `task_type` or `task_id` is - not. + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. """ - if not self._cluster_spec and cluster_spec: - # If a `cluster_spec` is already passed in, do nothing here. - # TODO(yuefengz): check `cluster_spec` is the same if this object has - # already been initialized with a `cluster_spec`. 
- if task_type is None or task_id is None: - raise ValueError("When `cluster_spec` is given, must also specify " - "`task_type` and `task_id`.") - self._cluster_spec = multi_worker_util.normalize_cluster_spec( - cluster_spec) - self._task_type = task_type - self._task_id = task_id - self._initialize_multi_worker(self._num_gpus_per_worker, - self._cluster_spec, task_type, task_id) - - if session_config: - session_config.CopyFrom(self._update_config_proto(session_config)) - - def _update_config_proto(self, config_proto): - updated_config = copy.deepcopy(config_proto) - if not self._cluster_spec: - updated_config.isolate_session_state = True - return updated_config + return super(ParameterServerStrategy, + self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) - updated_config.isolate_session_state = False - assert self._task_type - assert self._task_id is not None - - # The device filters prevent communication between workers. - if self._task_type not in ["chief", "worker"]: - return updated_config - del updated_config.device_filters[:] - updated_config.device_filters.extend( - ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) - return updated_config - - @property - def _num_replicas_in_sync(self): - return len(self._compute_devices) - - @property - def worker_devices(self): - return self._compute_devices - - @property - def parameter_devices(self): - return self._parameter_devices - - def non_slot_devices(self, var_list): - return min(var_list, key=lambda x: x.name) - - @property - def experimental_between_graph(self): - # TODO(yuefengz): Should this return False in the local case? 
- return True - - @property - def experimental_should_init(self): - return self._is_chief +class ParameterServerExtended(CoreParameterServerExtended): + """Implementation of ParameterServerStrategy.""" - @property - def should_checkpoint(self): - return self._is_chief + def __init__(self, container_strategy, num_gpus_per_worker): + # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change + # the constructor's interface to allow customized cluster resolver. Use + # SimpleClusterResolver to override num_accelerators. + tfconfig = TFConfigClusterResolver() + cluster_resolver = SimpleClusterResolver( + cluster_spec=tfconfig.cluster_spec(), + task_type=tfconfig.task_type, + task_id=tfconfig.task_id, + num_accelerators=num_gpus_per_worker) + super(ParameterServerExtended, self).__init__( + container_strategy, cluster_resolver=cluster_resolver) - @property - def should_save_summary(self): - return self._is_chief + def _make_dataset_iterator(self, dataset): + return input_lib.DatasetIterator(dataset, self._input_workers) # TODO(priyag): Delete this once all strategies use global batch size. 
@property def _global_batch_size(self): + """The contrib version of PS strategy uses per-replica batch size.""" return False diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 83d7473666a65e438a1c0119d2a12bf54e53c8fc..3de2041ae35775de6df5bca02c0f1d04a9c2f24e 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -29,10 +29,13 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import parameter_server_strategy as core_parameter_server_strategy from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config @@ -45,10 +48,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import training_util +from tensorflow.python.training.server_lib import ClusterSpec CHIEF = run_config.TaskType.CHIEF WORKER = run_config.TaskType.WORKER @@ -62,6 
+67,57 @@ def _get_replica_id_integer(): return replica_id +class MockCoreParameterServerStrategy(distribute_lib.DistributionStrategy): + """Mock the strategy to allow cluster resolver as an argument.""" + + def __init__(self, cluster_resolver): + super(MockCoreParameterServerStrategy, self).__init__( + core_parameter_server_strategy.ParameterServerStrategyExtended( + self, cluster_resolver=cluster_resolver)) + + +def create_test_objects(cluster_spec=None, + task_type=None, + task_id=None, + num_gpus=None, + sess_config=None, + use_core_strategy=False): + sess_config = sess_config or config_pb2.ConfigProto() + if num_gpus is None: + num_gpus = context.num_gpus() + if use_core_strategy: + if cluster_spec and task_type and task_id is not None: + cluster_resolver = SimpleClusterResolver( + cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), + task_type=task_type, + task_id=task_id, + num_accelerators=num_gpus) + target = 'grpc://' + cluster_spec[WORKER][task_id] + else: + cluster_resolver = SimpleClusterResolver( + ClusterSpec({}), num_accelerators=num_gpus) + target = '' + + distribution = MockCoreParameterServerStrategy(cluster_resolver) + sess_config = copy.deepcopy(sess_config) + sess_config = distribution.update_config_proto(sess_config) + else: + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if task_type: + sess_config = copy.deepcopy(sess_config) + distribution.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[WORKER][task_id] + else: + target = '' + + return distribution, target, sess_config + + class ParameterServerStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): @@ -75,24 +131,27 @@ class ParameterServerStrategyTestBase( self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True) super(ParameterServerStrategyTestBase, self).setUp() - def _get_test_objects(self, 
task_type, task_id, num_gpus): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=num_gpus) - if not task_type: - return distribution, '', self._sess_config - - sess_config = copy.deepcopy(self._sess_config) - distribution.configure( - session_config=sess_config, + def _get_test_objects(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): + return create_test_objects( cluster_spec=self._cluster_spec, task_type=task_type, - task_id=task_id) - return (distribution, 'grpc://' + self._cluster_spec[WORKER][task_id], - sess_config) - - def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + task_id=task_id, + num_gpus=num_gpus, + sess_config=self._sess_config, + use_core_strategy=use_core_strategy) + + def _test_device_assignment_distributed(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) - d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + d, _, sess_config = self._get_test_objects( + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) with ops.Graph().as_default(), \ self.cached_session(target=self._default_target, config=sess_config) as sess, \ @@ -131,7 +190,7 @@ class ParameterServerStrategyTestBase( '/job:worker/replica:0/task:0/%s' % last_part_device) # The colocate_vars_with can override the distribution's device. 
- with d.colocate_vars_with(x): + with d.extended.colocate_vars_with(x): y = variable_scope.get_variable( 'y', initializer=20.0, aggregation=variable_scope.VariableAggregation.SUM) @@ -177,7 +236,7 @@ class ParameterServerStrategyTestBase( self.assertIn('/job:ps/', h.device) return y_add, z_add, f - y, z, f = d.call_for_each_replica(model_fn) + y, z, f = d.extended.call_for_each_replica(model_fn) self.assertNotEqual(y, None) self.assertNotEqual(z, None) self.assertNotEqual(f, None) @@ -190,9 +249,10 @@ class ParameterServerStrategyTestBase( self.assertEqual(f_val, 46.0) def _test_device_assignment_distributed_enable_partitioner( - self, task_type, task_id, num_gpus): - d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) - num_shards = len(d.parameter_devices) + self, task_type, task_id, num_gpus, use_core_strategy=False): + d, _, sess_config = self._get_test_objects( + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) + num_shards = len(d.extended.parameter_devices) partitioner = partitioned_variables.fixed_size_partitioner(num_shards) with ops.Graph().as_default(), \ self.cached_session(target=self._default_target, @@ -224,39 +284,18 @@ class ParameterServerStrategyTestBase( self.assertEqual(var.device, '/job:ps/task:%d' % part_id) self.assertEqual(var.device, x_add[part_id].device) - # The colocate_vars_with can override the distribution's device. 
- with d.colocate_vars_with(x_add[0]): - y = variable_scope.get_variable( - 'y', - initializer=constant_op.constant([20.0, 10.0]), - aggregation=variable_scope.VariableAggregation.SUM, - partitioner=partitioner) - y_add = y.assign_add( - [array_ops.identity(x_add[0]), - array_ops.identity(x_add[1])]) - - for part_id, var in enumerate(y): - self.assertEqual(var.device, '/job:ps/task:0') - self.assertEqual(y_add[part_id].device, var.device) - self.assertEqual(var.device, x_add[0].device) - - return x_add, y_add + return x_add - x, y = d.call_for_each_replica(model_fn) + x = d.extended.call_for_each_replica(model_fn) if context.num_gpus() >= 1: variables.global_variables_initializer().run() - x_val, y_val = sess.run([x, y]) + x_val = sess.run(x) if num_gpus < 1: self.assertEqual(x_val, [13.0, 25.0]) - self.assertEqual(y_val, [33.0, 35.0]) else: x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus] - y_expect = [ - 20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus - ] self.assertEqual(x_val, x_expect) - self.assertEqual(y_val, y_expect) def _test_device_assignment_local(self, d, @@ -305,7 +344,7 @@ class ParameterServerStrategyTestBase( self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2')) # The colocate_vars_with can override the distribution's device. 
- with d.colocate_vars_with(x): + with d.extended.colocate_vars_with(x): y = variable_scope.get_variable( 'y', initializer=20.0, aggregation=variable_scope.VariableAggregation.SUM) @@ -348,7 +387,7 @@ class ParameterServerStrategyTestBase( device_util.canonicalize(h.device)) return y_add, z_add, f - y, z, f = d.call_for_each_replica(model_fn) + y, z, f = d.extended.call_for_each_replica(model_fn) self.assertNotEqual(y, None) self.assertNotEqual(z, None) self.assertNotEqual(f, None) @@ -360,9 +399,13 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - def _test_simple_increment(self, task_type, task_id, num_gpus): + def _test_simple_increment(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) if d.extended._cluster_spec: num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) if 'chief' in d.extended._cluster_spec.as_dict(): @@ -395,7 +438,7 @@ class ParameterServerStrategyTestBase( train_op = control_flow_ops.group(x_add, y_add, z_add) return x, y, z, train_op - x, y, z, train_op = d.call_for_each_replica(model_fn) + x, y, z, train_op = d.extended.call_for_each_replica(model_fn) train_op = d.group(train_op) if context.num_gpus() < d.extended._num_gpus_per_worker: @@ -430,9 +473,13 @@ class ParameterServerStrategyTestBase( y_val == 20.0 + 1.0 * num_workers * d.num_replicas_in_sync and z_val == 30.0 + 1.0 * num_workers) - def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + def _test_minimize_loss_graph(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) if task_type: # Multi-worker assert hasattr(d.extended, '_cluster_spec') and 
d.extended._cluster_spec @@ -472,20 +519,20 @@ class ParameterServerStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( - d.update(v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + d.extended.update(v, update, args=(g,), group=False)): + after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -518,10 +565,16 @@ class ParameterServerStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before - def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, - expected_values): + def _test_input_fn_iterator(self, + task_type, + task_id, + num_gpus, + input_fn, + expected_values, + test_reinitialize=True, + use_core_strategy=False): distribution, master_target, config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) devices = distribution.extended.worker_devices with ops.Graph().as_default(), \ @@ -532,27 +585,31 @@ class ParameterServerStrategyTestBase( for expected_value in expected_values: next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) with 
self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - sess.run([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. - sess.run(iterator.initialize()) + if test_reinitialize: + sess.run(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) -class ParameterServerStrategyTest(ParameterServerStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class ParameterServerStrategyTest( + ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -560,111 +617,175 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2) cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0] - def test_num_replicas_in_sync(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def test_num_replicas_in_sync(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) # All the devices on a given worker are in sync which in this case is the # number of gpus on each worker. 
- self.assertEqual(2, distribution.num_replicas_in_sync) + self.assertEqual(2, strategy.num_replicas_in_sync) - def testDeviceAssignmentLocalCPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=0) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalCPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=0, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + strategy, compute_device='CPU', variable_device='CPU', num_gpus=0) - def testDeviceAssignmentLocalOneGPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=1) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalOneGPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=1, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + strategy, compute_device='GPU', variable_device='GPU', num_gpus=1) - def testDeviceAssignmentLocalTwoGPUs(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalTwoGPUs(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + strategy, compute_device='GPU', variable_device='CPU', num_gpus=2) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributed(self, num_gpus): - 
self._test_device_assignment_distributed('worker', 1, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testDeviceAssignmentDistributed(self, num_gpus, use_core_strategy): + self._test_device_assignment_distributed( + 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus, + use_core_strategy): self._test_device_assignment_distributed_enable_partitioner( - 'worker', 1, num_gpus) + 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) - def testSimpleBetweenGraph(self): - self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testSimpleBetweenGraph(self, use_core_strategy): + self._run_between_graph_clients( + self._test_simple_increment, + self._cluster_spec, + context.num_gpus(), + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testLocalSimpleIncrement(self, num_gpus): - self._test_simple_increment(None, 0, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testLocalSimpleIncrement(self, num_gpus, use_core_strategy): + self._test_simple_increment(None, 0, num_gpus, use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraphDistributed(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], 
use_core_strategy=[True, False])) + def testMinimizeLossGraphDistributed(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraphLocal(self, num_gpus): - self._test_minimize_loss_graph(None, None, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): + self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. # TODO(priyag): Refactor this and other multi worker tests. @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) - def testMakeInputFnIteratorDistributed(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[1, 2], + required_gpus=1, + use_core_strategy=[True, False], + use_dataset=[True, False])) + def DISABLED_testMakeInputFnIteratorDistributed( + self, num_gpus, use_core_strategy, use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next expected_values = [[i+j for j in range(num_gpus)] for i in range(0, 100, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=3, expected_input_pipeline_id=1) # because task_id = 1 - self._test_input_fn_iterator('worker', 1, num_gpus, - input_fn, expected_values) - + self._test_input_fn_iterator( + 'worker', + 1, + num_gpus, + input_fn, + expected_values, + 
test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) + + # TODO(b/124344198): Re-enable after fixing this flaky test. @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) - def testMakeInputFnIteratorLocal(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[1, 2], + required_gpus=1, + use_core_strategy=[True, False], + use_dataset=[True, False])) + def DISABLED_testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, + use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next expected_values = [[i+j for j in range(num_gpus)] for i in range(0, 100, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=1, expected_input_pipeline_id=0) # only one worker and pipeline for local. 
- self._test_input_fn_iterator(None, None, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + None, + None, + num_gpus, + input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) - def testGlobalStepUpdate(self): - strategy = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepUpdate(self, use_core_strategy): + strategy, _, _ = create_test_objects(use_core_strategy=use_core_strategy) self._test_global_step_update(strategy) - def testUpdateConfigProtoMultiWorker(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - distribution.configure( + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProtoMultiWorker(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + strategy.configure( cluster_spec=self._cluster_spec, task_type='worker', task_id=1) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify device filters. 
self.assertEqual(['/job:worker/task:1', '/job:ps'], @@ -673,16 +794,48 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, # Verify isolate_session_state self.assertFalse(new_config.isolate_session_state) - def testUpdateConfigProtoLocal(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProtoLocal(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) config_proto = config_pb2.ConfigProto() - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify isolate_session_state self.assertTrue(new_config.isolate_session_state) + def testAllReduceSum(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradient_tape(distribution) + class 
ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): @@ -693,20 +846,31 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2, has_chief=True) cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0] - def testSimpleBetweenGraph(self): - self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testSimpleBetweenGraph(self, use_core_strategy): + self._run_between_graph_clients( + self._test_simple_increment, + self._cluster_spec, + context.num_gpus(), + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraph(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraph(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) - def testGlobalStepIsWrapped(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - with ops.Graph().as_default(), distribution.scope(): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepIsWrappedOnTwoGPUs(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): created_step = training_util.create_global_step() get_step = training_util.get_global_step() self.assertEqual(created_step, get_step, @@ -715,19 +879,55 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, 
id(get_step), get_step.__class__.__name__))) self.assertIs(values.AggregatingVariable, type(created_step)) self.assertIs(values.AggregatingVariable, type(get_step)) + self.assertIs(strategy, created_step.distribute_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepIsNotWrappedOnOneGPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=1, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): + created_step = training_util.create_global_step() + get_step = training_util.get_global_step() + self.assertEqual(created_step, get_step, + msg=('created_step %s type %s vs. get_step %s type %s' % + (id(created_step), created_step.__class__.__name__, + id(get_step), get_step.__class__.__name__))) + self.assertIs(resource_variable_ops.ResourceVariable, type(created_step)) + self.assertIs(resource_variable_ops.ResourceVariable, type(get_step)) + # All variables have an _distribute_strategy parameter. Only variable + # subclasses in distribution strategy expose it publicly. 
+ self.assertFalse(hasattr(strategy, 'distribute_strategy')) + self.assertIs(strategy, created_step._distribute_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testValueContainer(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): - def testValueContainer(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - with ops.Graph().as_default(), distribution.scope(): def f(): with backprop.GradientTape() as tape: v = variable_scope.get_variable('v', initializer=10.0) _ = v * v v, = tape.watched_variables() - w = distribution.extended.value_container(v) + w = strategy.extended.value_container(v) self.assertIs(values.AggregatingVariable, type(w)) - distribution.extended.call_for_each_replica(f) + + strategy.extended.call_for_each_replica(f) + + +class LocalParameterServerStrategyTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine(mode=['graph', 'eager'], + use_core_strategy=[True, False], + required_gpus=2)) + def testNumpyIterator(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + self._test_numpy_iterator(strategy) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index c928b6d9f1f21508edd753f94c38ab2723cc0a9f..27aad46b97195aa498d0382f08c04c312cebbe65 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import backprop -from tensorflow.python.eager import context from tensorflow.python.training import optimizer as optimizer_lib @@ -33,6 +32,9 @@ 
class Step(object): def distribution(self): return self._distribution + def initialize(self): + return [] + def __call__(self): """Perform one step of this training algorithm.""" raise NotImplementedError("must be implemented in descendants") @@ -50,12 +52,10 @@ class StandardInputStep(Step): def __init__(self, dataset_fn, distribution): super(StandardInputStep, self).__init__(distribution) - self._distributed_input = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - self._iterator = self._distributed_input.make_one_shot_iterator() - else: - # TODO(priyag): Expose initializer via some initializer property. - self._iterator = self._distributed_input.make_initializable_iterator() + self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) + + def initialize(self): + return self._iterator.initialize() class StandardSingleLossStep(StandardInputStep): @@ -99,8 +99,8 @@ class StandardSingleLossStep(StandardInputStep): gradients_fn = backprop.implicit_grad(self._loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = self.distribution.call_for_each_replica( - gradients_fn, args=(ctx,) + inputs) + grads_and_vars = self.distribution.extended.call_for_each_replica( + gradients_fn, args=(ctx, inputs)) # If threads use layers, then we need to run the first step # sequentially, so that layers.build() is not executed in parallel. # Otherwise, multiple sets of mirrored variables are going to be @@ -109,6 +109,6 @@ class StandardSingleLossStep(StandardInputStep): self.distribution, grads_and_vars) # TODO(priyag): Return the outputs, context, etc as well. 
- ctx = self.distribution.run_steps_on_dataset( + ctx = self.distribution.extended.experimental_run_steps_on_iterator( step_fn, self._iterator, self._iterations_per_step) return ctx.run_op diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 1ff9b9ceec13351b098d47ed3ff62f689a625a31..9f48560b2666036e149a63c98b6529fb24cc5067 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -45,24 +45,21 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): single_loss_step, layer = single_loss_example( optimizer_fn, distribution, use_bias=True, iterations_per_step=2) - self.evaluate(distribution.initialize()) if context.executing_eagerly(): + single_loss_step.initialize() run_step = single_loss_step else: with self.cached_session() as sess: - sess.run(single_loss_step._iterator.initializer) + sess.run(single_loss_step.initialize()) run_step = sess.make_callable(single_loss_step()) self.evaluate(variables.global_variables_initializer()) weights, biases = [], [] for _ in range(5): run_step() - weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index d441b5af5f6aa41efde2c75d09d9589516c54992..90f552eda4c41742f21ca276d8a059b2b102554f 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.core.protobuf import 
config_pb2 +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values @@ -31,6 +34,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -41,25 +45,26 @@ class _TestException(Exception): pass -# May be the argument to either distribution.call_for_each_replica() or +# May be the argument to either distribution.extended.call_for_each_replica() or # get_replica_context().merge_call() def _raise_exception_fn(_=None): raise _TestException() -# Must be the argument to a distribution.call_for_each_replica() call, calls a -# get_replica_context().merge_call() that raises an exception. +# Must be the argument to a distribution.extended.call_for_each_replica() call, +# calls a get_replica_context().merge_call() that raises an exception. def _merge_raises_fn(): ds_context.get_replica_context().merge_call(_raise_exception_fn) # Must be the argument to a get_replica_context().merge_call() call, calls -# dist.call_for_each_replica() with a function that raises an exception. +# dist.extended.call_for_each_replica() with a function that raises an +# exception. def _call_raises_fn(dist): - dist.call_for_each_replica(_raise_exception_fn) + dist.extended.call_for_each_replica(_raise_exception_fn) -# Must be the argument to a distribution.call_for_each_replica() call, +# Must be the argument to a distribution.extended.call_for_each_replica() call, # calls a get_replica_context().merge_call() that calls a # call_for_each_replica() that raises an exception. 
def _merge_call_raises_fn(): @@ -67,15 +72,16 @@ def _merge_call_raises_fn(): # Must be the argument to a get_replica_context().merge_call() call, calls -# dist.call_for_each_replica() with a function that calls a +# dist.extended.call_for_each_replica() with a function that calls a # get_replica_context().merge_call() that raises an exception. def _call_merge_raises_fn(dist): - dist.call_for_each_replica(_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_raises_fn) -# Must be the argument to a distribution.call_for_each_replica() call, calls a -# get_replica_context().merge_call() that calls a call_for_each_replica() that -# calls a get_replica_context().merge_call() that raises an exception. +# Must be the argument to a distribution.extended.call_for_each_replica() call, +# calls a get_replica_context().merge_call() that calls a +# call_for_each_replica() that calls a get_replica_context().merge_call() that +# raises an exception. def _merge_call_merge_raises_fn(): ds_context.get_replica_context().merge_call(_call_merge_raises_fn) @@ -106,21 +112,21 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. 
before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) - with ops.control_dependencies(d.update( - v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + with ops.control_dependencies(d.extended.update( + v, update, args=(g,), group=False)): + after_list.append(d.extended.read_var(v)) return before_list, after_list for i in range(10): @@ -162,20 +168,20 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) - with ops.control_dependencies(d.update( - v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + with ops.control_dependencies(d.extended.update( + v, update, args=(g,), group=False)): + after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -202,23 +208,23 @@ class DistributionTestBase(test.TestCase): self.assertFalse(expected_devices[replica_id]) expected_devices[replica_id] = True - d.call_for_each_replica(mark_devices_fn) + d.extended.call_for_each_replica(mark_devices_fn) self.assertAllEqual(expected_devices, [True] * len(d.extended.worker_devices)) def _test_call_and_merge_exceptions(self, dist): with dist.scope(): with self.assertRaises(_TestException): - 
dist.call_for_each_replica(_raise_exception_fn) + dist.extended.call_for_each_replica(_raise_exception_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_raises_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_call_raises_fn) + dist.extended.call_for_each_replica(_merge_call_raises_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_call_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_call_merge_raises_fn) def _input_fn_to_test_input_context(self, - dataset_fn, + dataset_or_callable_fn, expected_num_replicas_in_sync, expected_num_input_pipelines, expected_input_pipeline_id): @@ -242,33 +248,35 @@ class DistributionTestBase(test.TestCase): self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id) worker_id_counter[0] += 1 - return dataset_fn() + return dataset_or_callable_fn() return _input_fn def _test_input_fn_iterator(self, iterator, devices, expected_values, - sess=None): + sess=None, test_reinitialize=True): evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) evaluate(iterator.initialize()) for expected_value in expected_values: next_element = iterator.get_next() computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) + [values.select_replica(r, next_element) for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - evaluate([values.select_device(d, next_element) for d in devices]) + evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. 
- evaluate(iterator.initialize()) + if test_reinitialize: + evaluate(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) def _test_global_step_update(self, strategy): with strategy.scope(): @@ -286,8 +294,195 @@ class DistributionTestBase(test.TestCase): value = global_step.read_value() return train_op, value - train_ops, value = strategy.call_for_each_replica(model_fn) + train_ops, value = strategy.extended.call_for_each_replica(model_fn) self.evaluate(strategy.group(train_ops)) global_step_tensors = strategy.unwrap(value) global_step_values = self.evaluate(global_step_tensors) self.assertEqual((1,) * len(global_step_tensors), global_step_values) + + def _test_numpy_iterator(self, strategy): + with strategy.scope(), self.cached_session() as sess: + x = np.asarray([[1, 2], [6, 12], [2, 4], + [5, 10], [3, 6], [4, 8]]) + y = np.asarray([5, 4, 3, 2, 1, 0]) + batch_size = 6 + if not strategy.extended._global_batch_size: # pylint: disable=protected-access + batch_size = batch_size // strategy.num_replicas_in_sync + i = strategy.experimental_make_numpy_iterator( + (x, y), batch_size=batch_size, num_epochs=2, shuffle=None, + session=sess) + self.evaluate(i.initialize()) + + def run_and_concatenate(strategy, i): + x, y = strategy.experimental_run(lambda z: z, i) + x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y))) + return np.concatenate(x), np.concatenate(y) + + x_1, y_1 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_1) + self.assertAllEqual(y, y_1) + x_2, y_2 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_2) + 
self.assertAllEqual(y, y_2) + with self.assertRaises(errors.OutOfRangeError): + run_and_concatenate(strategy, i) + + +class OneDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any one-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0]], outputs[0]) + self.assertAllEqual([expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) 
+ y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) + with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +class TwoDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any two-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(2., [21., 21.5])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def 
_test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0], expected[0]], outputs[0]) + self.assertAllEqual([expected[1], expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) + y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) 
+ with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +def _all_sum(value): + ctx = ds_context.get_replica_context() + return ctx.all_reduce(reduce_util.ReduceOp.SUM, value) + + +def _all_mean(value): + ctx = ds_context.get_replica_context() + return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index b6f5b492017fc7dfd329e69ad9ca418ae682bc4b..2d9d221f427422f8bbeba55c5644658af9a9a620 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,10 +21,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy -import functools from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import device_assignment as device_assignment_lib +from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional_ops +from tensorflow.contrib.tpu.python.tpu import topology from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop @@ -33,11 +36,16 @@ from tensorflow.python.client import session as session_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import numpy_dataset from 
tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op +from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -46,9 +54,58 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat from tensorflow.python.util import nest +def initialize_tpu_system(cluster_resolver=None): + """Initialize the TPU devices in a separate session and graph. + + Args: + cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, + which provides information about the TPU cluster. + Returns: + The tf.contrib.tpu.Topology object for the topology of the TPU cluster. + """ + if cluster_resolver is None: + cluster_resolver = TPUClusterResolver("") + master = cluster_resolver.master() + + logging.info("Initializing the TPU system.") + + if context.executing_eagerly(): + # This function looks as it is for the following non-intuitive reasons. + # tpu.initialize_system creates a dummy op whose sole purpose is to trigger + # DistributedTPURewritePass. This pass actually adds real ops that + # initialize the TPU system. Thus, we can't simply run tpu.initialize_system + # eagerly. We need to wrap it in defun and trigger the rewrite passes on it. + # The easiest way to trigger a rewrite is to run the function with + # TPUPartitionedCallOp. 
+ @function.defun + def _tpu_init_fn(): + return tpu.initialize_system() + + # We can't call _tpu_init_fn normally (because it contains just a dummy op, + # see above) but need to define it to get it added to eager context + # and get its assigned name. + # pylint: disable=protected-access + graph_func = _tpu_init_fn._get_concrete_function_internal() + func_name = compat.as_str(graph_func._inference_function.name) + # pylint: enable=protected-access + + output = tpu_functional_ops.TPUPartitionedCall( + args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name) + serialized_topology = output[0].numpy() + else: + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + serialized_topology = sess.run(tpu.initialize_system()) + + logging.info("Finished initializing TPU system.") + return topology.Topology(serialized=serialized_topology) + + def get_tpu_system_metadata(tpu_cluster_resolver): """Retrieves TPU system metadata given a TPUClusterResolver.""" master = tpu_cluster_resolver.master() @@ -66,13 +123,14 @@ def get_tpu_system_metadata(tpu_cluster_resolver): # TODO(jhseu): Deduplicate with MirroredStrategy? -def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, - **kwargs): # pylint: disable=g-missing-docstring +def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring + strategy, device_map, logical_device, real_mirrored_creator, + *args, **kwargs): # Figure out what collections this variable should be added to. # We'll add the TPUMirroredVariable to those collections instead. 
- collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] + var_collections = kwargs.pop("collections", None) + if var_collections is None: + var_collections = [ops.GraphKeys.GLOBAL_VARIABLES] kwargs["collections"] = [] # TODO(jhseu): Should we have different behavior for different @@ -97,10 +155,13 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, # was never recorded on the tape instead of having to do this manually # here. with tape.stop_recording(): - index = real_mirrored_creator(devices, *args, **kwargs) - result = values.TPUMirroredVariable(index, index[devices[0]], aggregation) + devices = device_map.logical_to_actual_devices(logical_device) + value_list = real_mirrored_creator(devices, *args, **kwargs) + result = values.TPUMirroredVariable( + strategy, device_map, value_list, aggregation, + logical_device=logical_device) - if not context.executing_eagerly(): + if not (context.executing_eagerly() or ops.inside_function()): g = ops.get_default_graph() # If "trainable" is True, next_creator() will add the member variables # to the TRAINABLE_VARIABLES collection, so we manually remove @@ -108,18 +169,21 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, # "trainable" to False for next_creator() since that causes functions # like implicit_gradients to skip those variables. 
if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): + for v in value_list: l.remove(v) - g.add_to_collections(collections, result) + g.add_to_collections(var_collections, result) return result class TPUStrategy(distribute_lib.DistributionStrategy): """TPU distribution strategy implementation.""" - def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): + def __init__(self, + tpu_cluster_resolver=None, + steps_per_run=None, + device_assignment=None): """Initializes the TPUStrategy object. Args: @@ -130,72 +194,124 @@ class TPUStrategy(distribute_lib.DistributionStrategy): metrics, summaries etc. This parameter is only used when Distribution Strategy is used with estimator or keras. - num_cores: Number of cores to use on the TPU. If None specified, then - auto-detect the cores and topology of the TPU system. + device_assignment: Optional `tf.contrib.tpu.DeviceAssignment` to specify + the placement of replicas on the TPU cluster. Currently only supports + the usecase of using a single core within a TPU cluster. """ super(TPUStrategy, self).__init__(TPUExtended( - self, tpu_cluster_resolver, steps_per_run, num_cores)) + self, tpu_cluster_resolver, steps_per_run, device_assignment)) @property def steps_per_run(self): """DEPRECATED: use .extended.steps_per_run instead.""" return self._extended.steps_per_run + # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this + # can use the default implementation. + # This implementation runs a single step. It does not use infeed or outfeed. 
+ def experimental_run(self, fn, input_iterator=None): + """See base class.""" + if context.executing_eagerly() and not ops.inside_function(): + raise NotImplementedError( + "Eager mode not supported in TPUStrategy outside TF functions.") + + if input_iterator is None: + inputs = [] + else: + inputs = input_iterator.get_next() + + result = [None] + def replicated_fn(replica_id, replica_input): + """Wraps user function to provide replica ID and `Tensor` inputs.""" + with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id): + if input_iterator is None: + result[0] = fn() + else: + result[0] = fn(replica_input) + return result[0] + + replicate_inputs = [] # By replica. + for i in range(self.num_replicas_in_sync): + replicate_inputs.append( + [constant_op.constant(i, dtype=dtypes.int32), + values.select_replica(i, inputs)]) + + with self.scope(): + replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs) + + # Workaround for `tpu.replicate` behaviour when single `Tensor` returned. + replicate_outputs = [ + nest.pack_sequence_as(result[0], nest.flatten(replica_outputs)) + for replica_outputs in replicate_outputs] + + device_map = self.extended._device_map # pylint: disable=protected-access + return values.regroup(device_map, replicate_outputs) + class TPUExtended(distribute_lib.DistributionStrategyExtended): """Implementation of TPUStrategy.""" - # Track what TPU devices have been initialized. - _initialized_devices = [] - - def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, - num_cores=None): + def __init__(self, + container_strategy, + tpu_cluster_resolver=None, + steps_per_run=None, + device_assignment=None): super(TPUExtended, self).__init__(container_strategy) + + if tpu_cluster_resolver is None: + tpu_cluster_resolver = TPUClusterResolver("") + + if steps_per_run is None: + # TODO(frankchn): Warn when we are being used by DS/Keras and this is + # not specified. 
+ steps_per_run = 1 + self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) - # TODO(sourabhbajaj): Change this from num_cores to metadata_override - self._num_cores_override = num_cores + self._device_assignment = device_assignment + + # Device assignment is currently only supported for 1 core case. + if self._device_assignment: + assert isinstance(self._device_assignment, + device_assignment_lib.DeviceAssignment) + if self._device_assignment.num_replicas != 1: + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") + if self._device_assignment.num_cores_per_replica != 1: + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") + if not all(self._device_assignment.core_assignment[0][0] == [0, 0, 0]): + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") # TODO(jhseu): Switch to DeviceAssignment to support pods and model # parallelism. - device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) - if "device:TPU:" in d.name} - self._device_index = values.PerReplica(device_map) + self._device_index = { + d.name: i for i, d in enumerate(self._tpu_metadata.devices) + if "device:TPU:" in d.name + } self._host_device = self.get_host_cpu_device(0) - self._tpu_devices = tuple(sorted(device_map.keys())) + self._tpu_devices = tuple(sorted(self._device_index.keys())) # Only create variables for the number of replicas we're running. self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] + self._device_map = values.ReplicaDeviceMap(self._tpu_devices) + + # Preload the data onto the TPUs. 
+ input_worker_devices = collections.OrderedDict() + for tpu_device in self._tpu_devices: + host_device = _get_host_for_device(tpu_device) + input_worker_devices.setdefault(host_device, []) + input_worker_devices[host_device].append(tpu_device) + self._input_workers = input_lib.InputWorkers( + self._device_map, tuple(input_worker_devices.items())) # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run self._require_static_shapes = True - # Initialize the TPU devices. - self._initialize_tpu() - - def _initialize_tpu(self): - """Initialize the TPU devices in a separate session and graph. - - We keep track of all the TPU devices that we're initialized as we should - only be running TPU initialize once for the entire process. - """ - master = self._tpu_cluster_resolver.master() - # Verify TPU has not already been initialized in this process. - if master in TPUExtended._initialized_devices: - logging.info("TPU master %s has already been initialized." % master) - return - - logging.info("Initializing the TPU system.") - session_config = config_pb2.ConfigProto(allow_soft_placement=True) - self._configure(session_config) - with ops.Graph().as_default(): - with session_lib.Session(config=session_config, target=master) as sess: - sess.run([tpu.initialize_system()]) - logging.info("Finized initializing TPU system.") - - # Update Strategy state to make sure we can track device initialization. 
- TPUExtended._initialized_devices.append(master) + def _validate_colocate_with_variable(self, colocate_with_variable): + values.validate_colocate_tpu_variable(colocate_with_variable, self) def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator, input_shapes, iterations): @@ -260,21 +376,27 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterators for each of the TPU hosts.""" - - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] - return values.DatasetIterator(dataset, worker_devices, - self._num_replicas_in_sync) - - def _distribute_dataset(self, dataset_fn): - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] - return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), worker_devices) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + input_contexts = [] + num_workers = self._input_workers.num_workers + for i in range(num_workers): + input_contexts.append(distribute_lib.InputContext( + num_input_pipelines=num_workers, + input_pipeline_id=i, + num_replicas_in_sync=self._num_replicas_in_sync)) + return input_lib.InputFunctionIterator( + input_fn, self._input_workers, input_contexts) + + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, numpy_dataset.SingleDevice(self.get_host_cpu_device(0)), + session) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have @@ -288,29 +410,16 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): "TPU currently requires fully defined shapes. 
Either use " "set_shape() on the input tensors or use " "dataset.batch(..., drop_remainder=True).") - types = nest.flatten(multi_worker_iterator.output_types) - - enqueue_ops = [ - self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes, - iterations) - for host_id in range(self.num_hosts)] - - def dequeue_fn(): - dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - return nest.pack_sequence_as(output_shapes, dequeued) # Wrap `fn` for repeat. if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() - def run_fn(): + def run_fn(inputs): """Single step on the TPU device.""" - fn_inputs = dequeue_fn() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, fn_inputs) + fn_result = fn(ctx, inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: with ops.control_dependencies([fn_result]): @@ -330,7 +439,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def rewrite_fn(*args): """The rewritten step fn running on TPU.""" del args - replicate_inputs = [[]] * self._num_replicas_in_sync + + per_replica_inputs = multi_worker_iterator.get_next() + replicate_inputs = [] + for replica_id in range(self._num_replicas_in_sync): + select_replica = lambda x: values.select_replica(replica_id, x) # pylint: disable=cell-var-from-loop + replicate_inputs.append((nest.map_structure( + select_replica, per_replica_inputs),)) + replicate_outputs = tpu.replicate(run_fn, replicate_inputs) # If run_fn has tensor outputs, tpu.replicate returns a list of list. 
We @@ -342,8 +458,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): return replicate_outputs - # TODO(sourabhbajaj): The input to while loop should be based on the output - # type of the step_fn + # TODO(sourabhbajaj): The input to while loop should be based on the + # output type of the step_fn assert isinstance(initial_loop_values, list) initial_loop_values = initial_loop_values * self._num_replicas_in_sync @@ -353,7 +469,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): initial_loop_values) del self._outer_control_flow_context - ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + ctx.run_op = control_flow_ops.group(replicate_outputs) if isinstance(replicate_outputs, list): # Filter out any ops from the outputs, typically this would be the case @@ -378,23 +494,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # no tensors returned. last_step_tensor_outputs = [] - # Convert replicate_outputs to the original dict structure of - # last_step_outputs. - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access - output = last_step_tensor_outputs_dict[name] - # For outputs that have already been reduced, take the first value - # from the list as each value should be the same. Else return the full - # list of values. - # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica - # value. 
- if reduce_op is not None: - # TODO(priyag): Should this return the element or a list with 1 element - last_step_tensor_outputs_dict[name] = output[0] - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - + _set_last_step_outputs(ctx, last_step_tensor_outputs) return ctx def _call_for_each_replica(self, fn, args, kwargs): @@ -403,57 +503,57 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): with _TPUReplicaContext(self._container_strategy()): return fn(*args, **kwargs) - def _initialize(self): - if context.executing_eagerly(): - # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - else: - return [] + def _experimental_initialize_system(self): + """Experimental method added to be used by Estimator. - def _finalize(self): - if context.executing_eagerly(): - # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - else: - return [] - - def _get_devices_from(self, colocate_with=None): - # TODO(jhseu): Change this when we support model parallelism. - return self._tpu_devices + This is a private method only to be used by Estimator. Other frameworks + should directly be calling `tf.contrib.distribute.initialize_tpu_system` + """ + initialize_tpu_system(self._tpu_cluster_resolver) def _create_variable(self, next_creator, *args, **kwargs): """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) + if colocate_with is None: + device_map = self._device_map + logical_device = 0 # TODO(josh11b): Get logical device from scope here. 
+ elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) + else: + device_map = colocate_with.device_map + logical_device = colocate_with.logical_device def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring - index = {} + value_list = [] for i, d in enumerate(devices): with ops.device(d): if i > 0: # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] + var0name = value_list[0].name.split(":")[0] # We append a / to variable names created on replicas with id > 0 to # ensure that we ignore the name scope and instead use the given # name as the absolute name of the variable. kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: - if context.executing_eagerly(): - kwargs["initial_value"] = array_ops.identity( - index[devices[0]].value()) + if context.executing_eagerly() or ops.inside_function(): + with ops.init_scope(): + kwargs["initial_value"] = array_ops.identity( + value_list[0].value()) else: def initial_value_fn(device=d): with ops.device(device): - return array_ops.identity(index[devices[0]].initial_value) + return array_ops.identity(value_list[0].initial_value) kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) assert not isinstance(v, values.TPUMirroredVariable) - index[d] = v - return index + value_list.append(v) + return value_list - return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, - **kwargs) + return _create_tpu_mirrored_variable( + self._container_strategy(), device_map, logical_device, + _real_mirrored_creator, *args, **kwargs) def _reduce_to(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access @@ -465,19 +565,32 @@ class 
TPUExtended(distribute_lib.DistributionStrategyExtended): "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) - # Validate that the destination is same as the host device - # Note we don't do this when in replicate context as the reduction is - # performed on the TPU device itself. + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerReplica or + # Mirrored values. For example, the same value could be present on all + # replicas in which case `value` would be a single value or value could + # be 0. + return cross_device_ops_lib.reduce_non_distributed_value( + reduce_op, self._device_map, value, destinations) + devices = cross_device_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - assert device_util.canonicalize(devices[0]) == device_util.canonicalize( - self._host_device) - else: + if len(devices) != 1: raise ValueError("Multiple devices are not supported for TPUStrategy") - output = math_ops.add_n(value) - if reduce_op == reduce_util.ReduceOp.MEAN: - return output * (1. / len(value)) + # Always performs the reduction on the TPU host. + with ops.device(self._host_device): + output = math_ops.add_n(value.values) + if reduce_op == reduce_util.ReduceOp.MEAN: + output *= (1. / len(value.values)) + + # If necessary, copy to requested destination. + dest_canonical = device_util.canonicalize(devices[0]) + host_canonical = device_util.canonicalize(self._host_device) + + if dest_canonical != host_canonical: + with ops.device(devices[0]): + output = array_ops.identity(output) + return output def _update(self, var, fn, args, kwargs, group): @@ -486,19 +599,19 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if group: return fn(var, *args, **kwargs) else: - return [fn(var, *args, **kwargs)] + return (fn(var, *args, **kwargs),) # Otherwise, we revert to MirroredStrategy behavior and update each variable # directly. 
- updates = {} - for d, v in var._index.items(): # pylint: disable=protected-access - name = "update_%d" % self._device_index.get(d) + updates = [] + for i, (d, v) in enumerate(zip(var.devices, var.values)): + name = "update_%d" % i with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): # If args and kwargs are not mirrored, the value is returned as is. - updates[d] = fn(v, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, group) + updates.append(fn(v, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs))) + return values.update_regroup(self, self._device_map, updates, group) def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) @@ -513,6 +626,11 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # be represented using a PerReplica wrapper instead of a list with # one entry per device. return tuple(val) + elif isinstance(val, values.TPUMirroredVariable): + # pylint: disable=protected-access + if values._enclosing_tpu_context() is not None: + return (val,) + return val.values return (val,) def value_container(self, value): @@ -524,15 +642,34 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): @property def num_hosts(self): - return self._tpu_metadata.num_hosts + if self._device_assignment is None: + return self._tpu_metadata.num_hosts + + return len(set([self._device_assignment.host_device(r) + for r in range(self._device_assignment.num_replicas)])) @property def num_replicas_per_host(self): - return self._tpu_metadata.num_of_cores_per_host + if self._device_assignment is None: + return self._tpu_metadata.num_of_cores_per_host + + # TODO(sourabhbajaj): Remove this method we use inputs and remove infeed + # as the computation of num_replicas_per_host is not a constant + # when using device_assignment. 
This is a temporary workaround to support + # StatefulRNN as everything is 1 in that case. + # This method needs to take host_id as input for correct computation. + max_models_per_host = (self._tpu_metadata.num_of_cores_per_host // + self._device_assignment.num_cores_per_replica) + models_per_host = min(self._device_assignment.num_replicas, + max_models_per_host) + return models_per_host * self._device_assignment.num_cores_per_replica @property def _num_replicas_in_sync(self): - return self._num_cores_override or self._tpu_metadata.num_cores + if self._device_assignment is None: + return self._tpu_metadata.num_cores + return (self._device_assignment.num_replicas * + self._device_assignment.num_cores_per_replica) @property def experimental_between_graph(self): @@ -600,23 +737,62 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. + + `make_input_fn_iterator` assumes per-replica batching. + + Returns: + Boolean. + """ return True class _TPUReplicaContext(distribute_lib.ReplicaContext): """Replication Context class for TPU Strategy.""" - # TODO(sourabhbajaj): Call for each tower should be updating this. - def __init__(self, distribution_strategy): + # TODO(sourabhbajaj): Call for each replica should be updating this. + # TODO(b/118385803): Always properly initialize replica_id. 
+ def __init__(self, strategy, replica_id_in_sync_group=None): + if replica_id_in_sync_group is None: + replica_id_in_sync_group = constant_op.constant(0, dtypes.int32) distribute_lib.ReplicaContext.__init__( - self, - distribution_strategy, - # TODO(b/118385803): properly initialize replica_id, instead of always 0 - replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) + self, strategy, replica_id_in_sync_group=replica_id_in_sync_group) @property def devices(self): distribute_lib.require_replica_context(self) - ds = self._distribution_strategy + ds = self._strategy replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return (ds.extended.worker_devices[replica_id],) + + if replica_id is None: # Non-constant `Tensor` inside `tpu.replicate`. + # TODO(cjfj): Return other devices when model parallelism is supported. + return (tpu.core(0),) + else: + return (ds.extended.worker_devices[replica_id],) + + +def _get_host_for_device(device): + spec = tf_device.DeviceSpec.from_string(device) + return tf_device.DeviceSpec( + job=spec.job, replica=spec.replica, task=spec.task, + device_type="CPU", device_index=0).to_string() + + +def _set_last_step_outputs(ctx, last_step_tensor_outputs): + """Sets the last step outputs on the given context.""" + # Convert replicate_outputs to the original dict structure of + # last_step_outputs. + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been reduced, take the first value + # from the list as each value should be the same. Else return the full + # list of values. + # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica + # value. 
+ if reduce_op is not None: + # TODO(priyag): Should this return the element or a list with 1 element + last_step_tensor_outputs_dict[name] = output[0] + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 538b859f3d1ece55b460f6dbf8f01540a6013381..9fd251175b8b8e3453e33434b4d86386a078295e 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -22,27 +22,20 @@ import os from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import saver as saver_lib -from tensorflow.python.util import nest class DistributedValuesTest(test.TestCase): @@ -51,7 +44,8 @@ class DistributedValuesTest(test.TestCase): with ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = 
values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) with self.assertRaises(ValueError): @@ -63,24 +57,26 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) with self.assertRaises(ValueError): self.assertIsNone(v.get("/device:GPU:2")) def testCanonicalization(self): - canonical_cpu = ["/job:localhost/replica:0/task:0/device:CPU:0"] - v = values.DistributedValues({"": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/device:CPU:0": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/cpu:0": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/CPU:0": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) + canonical_cpu = ("/job:localhost/replica:0/task:0/device:CPU:0",) + v = values.DistributedValues(values.SingleDeviceMap(""), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/device:CPU:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/cpu:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) with self.assertRaises(AssertionError): - v = values.DistributedValues({"/device:cpu:0": 42}) + v = 
values.DistributedValues( + values.SingleDeviceMap("/device:cpu:0"), (42,)) def testIsTensorLike(self): with context.graph_mode(), \ @@ -88,7 +84,8 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) self.assertTrue(v.is_tensor_like) @@ -100,7 +97,8 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = 2.0 - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) self.assertFalse(v.is_tensor_like) @@ -118,8 +116,8 @@ class DistributedDelegateTest(test.TestCase): def __init__(self, x): self.x = x - v = values.DistributedDelegate( - {"/device:CPU:0": Foo(7), "/device:GPU:0": Foo(8)}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedDelegate(device_map, (Foo(7), Foo(8))) self.assertEqual(7, v.x) with self.assertRaises(AttributeError): _ = v.y @@ -127,7 +125,8 @@ class DistributedDelegateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testOperatorOverride(self): with ops.device("/device:CPU:0"): - v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedDelegate(device_map, (7, 8)) # v should act like int(7). 
self.assertEqual(8, v + 1) self.assertEqual(10, 3 + v) @@ -178,16 +177,15 @@ def _nested_value(d): def _make_mirrored(): v = [] - index = {} devices = ["/device:GPU:0", "/device:CPU:0"] for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]): with ops.device(d): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) - index[d] = v[-1] - mirrored = values.MirroredVariable(index, v[0], + device_map = values.ReplicaDeviceMap(devices) + mirrored = values.MirroredVariable(None, device_map, v, variable_scope.VariableAggregation.SUM) - return v, devices, mirrored + return v, device_map, mirrored class RegroupAndSelectDeviceTest(test.TestCase): @@ -204,8 +202,9 @@ class RegroupAndSelectDeviceTest(test.TestCase): self.assertEqual(expected[i], result.get(_device_str(i))) def testNested(self): - result = values.regroup({_device_str(0): _nested_value("1"), - _device_str(1): _nested_value("2")}) + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, + (_nested_value("1"), _nested_value("2"))) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) self._is_per_replica(result[0], ["a1", "a2"]) @@ -221,11 +220,11 @@ class RegroupAndSelectDeviceTest(test.TestCase): self._is_per_replica(result[1][1]["c"], ["d1", "d2"]) self._is_per_replica(result[1][1]["e"], ["f1", "f2"]) - # Also test that we can undo the merge using select_device() + # Also test that we can undo the merge using select_replica() self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) self.assertEqual(_nested_value("2"), - values.select_device(_device_str(1), result)) + values.select_replica(1, result)) # select_device_mirrored() should fail due to non-mirrored values with self.assertRaises(TypeError): values.select_device_mirrored(_device_str(0), result) @@ -235,8 +234,9 @@ class RegroupAndSelectDeviceTest(test.TestCase): def testWrapClass(self): # 
Normally a mirrored value would be the same across devices, but # for a test it is convenient to be able to tell the values apart. - result = values.regroup({_device_str(0): _nested_value("1"), - _device_str(1): _nested_value("2")}, + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, + (_nested_value("1"), _nested_value("2")), values.Mirrored) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) @@ -253,11 +253,11 @@ class RegroupAndSelectDeviceTest(test.TestCase): self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored) self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored) - # Also test that we can undo the merge using select_device() + # Also test that we can undo the merge using select_replica() self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) self.assertEqual(_nested_value("2"), - values.select_device(_device_str(1), result)) + values.select_replica(1, result)) # Values are marked as mirrored, so select_device_mirrored() is allowed. 
self.assertEqual(_nested_value("1"), values.select_device_mirrored(_device_str(0), result)) @@ -267,63 +267,66 @@ class RegroupAndSelectDeviceTest(test.TestCase): def testMirroredContainer(self): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, devices, mirrored = _make_mirrored() - result = values.regroup(dict(zip(devices, v))) + v, device_map, mirrored = _make_mirrored() + result = values.regroup(device_map, v) self.assertIs(mirrored, result) def testSameId(self): foo = object() - result = values.regroup({_device_str(0): ("a", foo), - _device_str(1): ("b", foo)}) + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, (("a", foo), ("b", foo))) self.assertIsInstance(result, tuple) self.assertEqual(2, len(result)) self._is_per_replica(result[0], ["a", "b"]) self.assertIs(foo, result[1]) - # Test select_device(), should undo the merge done by regroup(). - result_0 = values.select_device(_device_str(0), result) + # Test select_replica(), should undo the merge done by regroup(). + result_0 = values.select_replica(0, result) self.assertIsInstance(result_0, tuple) self.assertEqual(2, len(result_0)) self.assertEqual("a", result_0[0]) self.assertIs(foo, result_0[1]) - result_1 = values.select_device(_device_str(1), result) + result_1 = values.select_replica(1, result) self.assertIsInstance(result_1, tuple) self.assertEqual(2, len(result_1)) self.assertEqual("b", result_1[0]) self.assertIs(foo, result_1[1]) def testOneDevice(self): - result = values.regroup({_device_str(0): _nested_value("1")}) - # On one device regroup() and select_device() are basically identity. + device_map = values.ReplicaDeviceMap((_device_str(0),)) + result = values.regroup(device_map, (_nested_value("1"),)) + # On one device regroup() and select_replica() are basically identity. 
self.assertEqual(_nested_value("1"), result) self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) # The one exception has to do with MirroredVariables. d = "/device:CPU:0" with ops.device(d): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - index = {d: v} - mirrored = values.MirroredVariable(index, v, + device_map = values.ReplicaDeviceMap((d,)) + mirrored = values.MirroredVariable(None, device_map, (v,), variable_scope.VariableAggregation.SUM) - result = values.regroup(index) + result = values.regroup(device_map, (v,)) self.assertIs(mirrored, result) def testNamedTupleEstimatorSpec(self): with context.graph_mode(), ops.Graph().as_default(): - created_estimator_specs = {} - to_regroup = {} + devices = [] + created_estimator_specs = [] for device_id in range(3): spec = model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.TRAIN, loss=constant_op.constant(device_id / 2), train_op=array_ops.identity(constant_op.constant(device_id))) - created_estimator_specs[device_id] = spec - to_regroup[_device_str(device_id)] = spec + devices.append(_device_str(device_id)) + created_estimator_specs.append(spec) - merged_estimator_spec = values.regroup(to_regroup) + device_map = values.ReplicaDeviceMap(devices) + merged_estimator_spec = values.regroup( + device_map, created_estimator_specs) self.assertTrue( isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) @@ -337,415 +340,10 @@ class RegroupAndSelectDeviceTest(test.TestCase): # Scaffold is populated by `EstimatorSpec.__new__`. 
self.assertEqual(created_estimator_specs[device_id].scaffold, merged_estimator_spec.scaffold.get(d)) - # Also test that we can undo the merge using select_device() + # Also test that we can undo the merge using select_replica() self.assertEqual(created_estimator_specs[device_id], - values.select_device(_device_str(device_id), - merged_estimator_spec)) - - -class PerReplicaDatasetTest(test.TestCase): - - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - - def _test_iterator(self, devices, dataset, expected_values): - per_replica_dataset = values.PerReplicaDataset(dataset, devices) - if context.executing_eagerly(): - iterator = per_replica_dataset.make_one_shot_iterator() - else: - iterator = per_replica_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = self.evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - self.evaluate([ - values.select_device(d, next_element) for d in devices]) - - @test_util.run_in_graph_and_eager_modes - def testOneDevice(self): - devices = ["/device:CPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleDevices(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testTupleDataset(self): - if context.num_gpus() < 
1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnevenDatasetBatches(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(devices, dataset, expected_values) - - def testInitializableIterator(self): - with context.graph_mode(): - devices = ["/device:CPU:0"] - # Using random input since that is only allowed with initializable - # iterator. - dataset = dataset_ops.Dataset.from_tensor_slices( - random_ops.random_uniform((10,))) - - per_replica_dataset = values.PerReplicaDataset(dataset, devices) - iterator = per_replica_dataset.make_initializable_iterator() - - self.evaluate(iterator.initializer) - next_element = iterator.get_next() - for _ in range(10): - self.evaluate(next_element) - - # Should fail after the input is finished. - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(next_element) - - # After re-initializing the iterator, should be able to iterate again. 
- self.evaluate(iterator.initializer) - for _ in range(10): - self.evaluate(next_element) - - -class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - - def _test_iterator(self, sess, iterator, devices, expected_values): - next_element = iterator.get_next() - for device in devices: - v = values.select_device(device, next_element) - # The `v` here can be a tuple. - for element in nest.flatten(v): - self.assertTrue(element.device in device) - - for expected_value in expected_values: - actual = sess.run( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, actual) - - with self.assertRaises(errors.OutOfRangeError): - sess.run([values.select_device(d, next_element) for d in devices]) - - def _test_dataset(self, dataset_fn, worker_devices, devices, - expected_values, auto_shard=True): - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=auto_shard) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.cached_session() as sess: - sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, expected_values) - - def _cpu_devices(self): - worker_devices = [ - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])] - devices = [ - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def _cpu_and_one_gpu_devices(self): - worker_devices = [ - ("/job:worker/replica:0/task:0", [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - ]), - ("/job:worker/replica:0/task:1", [ - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] - devices = [ - "/job:worker/replica:0/task:0/device:GPU:0", - 
"/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def testDataDistributionOneDevicePerWorker(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) - - def testDataDistributionNoAutoShard(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], - auto_shard=False) - - def testDataDistributionTwoDevicePerWorker(self): - if context.num_gpus() < 1: - self.skipTest("A GPU is not available for this test.") - worker_devices, devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 2, 1, 3], [4, 6, 5, 7]]) - - def testTupleDataset(self): - worker_devices, devices = self._cpu_devices() - - with context.graph_mode(): - - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(8) - dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [ - [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) - ] - self._test_dataset(dataset_fn, worker_devices, devices, - expected_values) - - def testInitializableIterator(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(8) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=True) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - - sess.run(multi_worker_iterator.initializer) - 
self._test_iterator(sess, multi_worker_iterator, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) - - # After re-initializing the iterator, should be able to iterate again. - sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) - - def testValueErrorForIterator(self): - # Incompatiable arguments. - with self.assertRaises(ValueError): - values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) - - # Test duplicated devices under same worker. - worker_devices, _ = self._cpu_devices() - worker_devices[0][1].append("/job:worker/replica:0/task:0/device:CPU:0") - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=True) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.assertRaises(ValueError): - multi_worker_iterator.get_next() - - -class InputIteratorTestBase(test.TestCase): - - def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, split_batch_by=None): - devices = nest.flatten([ds for _, ds in worker_device_pairs]) - - if input_type == "input_fn": - input_contexts = [ - distribute_lib.InputContext() for _ in worker_device_pairs] - input_fn = lambda _: dataset_fn() - iterator = values.InputFunctionIterator(input_fn, worker_device_pairs, - input_contexts) - else: - iterator = values.DatasetIterator(dataset_fn(), worker_device_pairs, - split_batch_by) - - evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertAllEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = 
iterator.get_next() - evaluate([values.select_device(d, next_element) for d in devices]) - - # After re-initializing the iterator, should be able to iterate again. - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertAllEqual(expected_value, computed_value) - - -class InputIteratorSingleWorkerTest(InputIteratorTestBase, - parameterized.TestCase): - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"])) - def testOneDeviceCPU(self, input_type): - worker_device_pairs = [("", ["/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesOneGPUOneCPU(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTupleDataset(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - 
@combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testUnevenDatasetBatches(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["dataset"], - split_batch_by=[None, 2], - required_gpus=1)) - def testBatchSplitting(self, input_type, split_batch_by): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - batch_size = 10 - dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) - - updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) - expected_values = [[range(i, i+updated_batch_size), - range(i+updated_batch_size, i+2*updated_batch_size)] - for i in range(0, 100, updated_batch_size*2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, - split_batch_by=split_batch_by) - - -class InputIteratorMultiWorkerTest( - multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, - parameterized.TestCase): - - def _cpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])] - - def _cpu_and_one_gpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - ]), - ("/job:worker/replica:0/task:1", [ - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def 
testOneDevicePerWorker(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesPerWorker(self, input_type): - worker_devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 1, 0, 1], [2, 3, 2, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testTupleDataset(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(4) - dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) + values.select_replica(device_id, + merged_estimator_spec)) class MirroredVariableTest(test.TestCase, parameterized.TestCase): @@ -768,8 +366,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) - index = {"/job:foo/device:CPU:0": v} - mirrored = values.MirroredVariable(index, v, + device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) + mirrored = values.MirroredVariable(None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, mirrored.name) @@ -797,7 +395,8 @@ class 
MirroredVariableTest(test.TestCase, parameterized.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.cached_session(config=self.config) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. self._assign_mirrored(devices, v, [3., 4.]) @@ -815,7 +414,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def _save_mirrored(self): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. self._assign_mirrored(devices, v, [3., 4.]) @@ -860,7 +460,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def _restore_mirrored(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. 
self._assign_mirrored(devices, v, [7., 8.]) @@ -904,25 +505,24 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - mirrored = values.MirroredVariable({ - "/device:GPU:0": v - }, v, variable_scope.VariableAggregation.MEAN) + mirrored = values.MirroredVariable( + distribution, values.ReplicaDeviceMap(("/device:GPU:0",)), (v,), + variable_scope.VariableAggregation.MEAN) sess.run(variables_lib.global_variables_initializer()) sess.run({"complicated": mirrored}) -_devices = ["/device:GPU:0", "/device:CPU:0"] +_devices = ("/device:GPU:0", "/device:CPU:0") -def _make_replica_local(method): +def _make_replica_local(method, strategy=None): + device_map = values.ReplicaDeviceMap(_devices) v = [] - index = {} for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): with ops.device(d): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) - index[d] = v[-1] - replica_local = values.ReplicaLocalVariable(index, v[0], method) + replica_local = values.ReplicaLocalVariable(strategy, device_map, v, method) return v, replica_local @@ -948,9 +548,9 @@ class ReplicaLocalVariablePropertiesTest(test.TestCase): def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) - index = {"/job:foo/device:CPU:0": v} + device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) replica_local = values.ReplicaLocalVariable( - index, v, variable_scope.VariableAggregation.MEAN) + None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, replica_local.name) self.assertEqual(v.dtype, replica_local.dtype) @@ -997,7 +597,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): with self.cached_session() as sess: v, replica_local = _make_replica_local( - 
variable_scope.VariableAggregation.SUM) + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1020,7 +620,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): with self.cached_session() as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1040,7 +640,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1056,7 +656,8 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, replica_local = _make_replica_local("sum") + v, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [1.5, 2.]) @@ -1103,7 +704,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. 
self._assign_replica_local(_devices, v, [7., 8.]) @@ -1118,7 +719,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 3079175015a9aee1625404902070df8f13b2089c..c2300286d3be4bb757dac588623c47044a1a9db5 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -822,7 +822,7 @@ cuda_py_test( cuda_py_test( name = "affine_test", - size = "large", + size = "medium", srcs = ["python/kernel_tests/bijectors/affine_test.py"], additional_deps = [ ":bijectors_py", @@ -837,7 +837,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], - shard_count = 5, + shard_count = 10, tags = ["noasan"], # times out b/63678675 ) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 452628257ea96713453bf2aa32b5baa9d6d0cb86..1006dfac49f36baa7cf5136f6f2982e3fd965298 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -249,9 +249,9 @@ class InverseGamma(distribution.Distribution): `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`.""") def _variance(self): - var = (math_ops.square(self.rate) - / math_ops.square(self.concentration - 1.) - / (self.concentration - 2.)) + var = ( + math_ops.square(self.rate) / math_ops.squared_difference( + self.concentration, 1.) 
/ (self.concentration - 2.)) if self.allow_nan_stats: nan = array_ops.fill( self.batch_shape_tensor(), diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index 978e627d6638ddeea9df288d389354f0ac53d115..19e99e03803e7f4cdfdb023feb04daaba68eceed 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -300,7 +300,7 @@ def percentile(x, raise ValueError("Argument 'interpolation' must be in %s. Found %s" % (allowed_interpolations, interpolation)) - with ops.name_scope(name, [x, q]): + with ops.name_scope(name, values=[x, q]): x = ops.convert_to_tensor(x, name="x") # Double is needed here and below, else we get the wrong index if the array # is huge along axis. diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 77052a75a70bec1162feb2b126d247924b3a2e36..d441e4735b64fe1176e77a978d281d46a7b287ab 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -15,7 +15,6 @@ py_library( ":metrics", ":network", ":parameter_server", - ":remote", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", @@ -31,6 +30,7 @@ py_library( "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", + "//tensorflow/python/eager:remote", ], ) @@ -144,7 +144,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", ], ) @@ -238,24 +238,12 @@ py_test( ], ) -py_library( - name = "remote", - srcs = ["remote.py"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:platform", - 
"//tensorflow/python/eager:context", - ], -) - cuda_py_test( name = "remote_test", srcs = ["remote_test.py"], additional_deps = [ ":parameter_server", - ":remote", + "//tensorflow/python/eager:remote", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 257d02057ae0d280074559aa9e97725bf5cc3fd0..48925b1bfacc6b59c210b2fb4b53a9a1a851673f 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,7 +37,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils class IteratorTest(test.TestCase): @@ -200,13 +200,6 @@ class IteratorTest(test.TestCase): y = math_ops.add(x, x) self.assertAllEqual([0., 2.], y.numpy()) - def testGpuDefinedDataset(self): - with ops.device(test.gpu_device_name()): - ds = Dataset.from_tensors([0., 1.]) - for x in ds: - y = math_ops.add(x, x) - self.assertAllEqual([0., 2.], y.numpy()) - def testOverrideThreadPool(self): def get_thread_id(_): @@ -245,7 +238,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) dataset = dataset.map(math_ops.square).batch(2) iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertAllEqual([1, 4], iterator.get_next().numpy()) save_path = checkpoint.save(checkpoint_prefix) self.assertAllEqual([9, 16], iterator.get_next().numpy()) @@ -264,7 +257,7 @@ class IteratorTest(test.TestCase): dataset_2 = 
Dataset.range(10) iterator_3 = datasets.Iterator(dataset_2) - checkpoint = checkpointable_utils.Checkpoint( + checkpoint = trackable_utils.Checkpoint( iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) self.assertAllEqual([1, 4], iterator_1.get_next().numpy()) self.assertEqual(0, iterator_3.get_next().numpy()) @@ -286,7 +279,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.range(3) iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertEqual(0, iterator.get_next().numpy()) self.assertEqual(1, iterator.get_next().numpy()) save_path = checkpoint.save(checkpoint_prefix) @@ -300,7 +293,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.range(10) for i in range(5): iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) checkpoint.restore(checkpoint_management.latest_checkpoint( checkpoint_directory)) for j in range(2): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 97c299a911c9180bf69faa0fa46527e80eada790..3e0881754c750f4d36e2e4dd8b80835b031c658c 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -6,16 +6,16 @@ package(default_visibility = ["//tensorflow:internal"]) py_library( name = "examples_pip", deps = [ - "//tensorflow/contrib/eager/python/examples/densenet", - "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/densenet:densenet_lib", + "//tensorflow/contrib/eager/python/examples/gan:mnist_lib", "//tensorflow/contrib/eager/python/examples/l2hmc", "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", - "//tensorflow/contrib/eager/python/examples/linear_regression", + 
"//tensorflow/contrib/eager/python/examples/linear_regression:linear_regression_lib", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/revnet", "//tensorflow/contrib/eager/python/examples/revnet:config", - "//tensorflow/contrib/eager/python/examples/rnn_colorbot", - "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/rnn_colorbot:rnn_colorbot_lib", + "//tensorflow/contrib/eager/python/examples/rnn_ptb:rnn_ptb_lib", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index e2154fcc5fcf774dcd52285d9442dfd5073a4992..fbb5daf230bb79f08a3d071062ddc0e8507ab324 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -9,6 +9,13 @@ py_binary( name = "densenet", srcs = ["densenet.py"], srcs_version = "PY2AND3", + deps = [":densenet_lib"], +) + +py_library( + name = "densenet_lib", + srcs = ["densenet.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -17,33 +24,37 @@ py_binary( cuda_py_test( name = "densenet_test", - size = "large", + size = "medium", srcs = ["densenet_test.py"], additional_deps = [ ":densenet", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", "optonly", + "oss_serial", ], ) cuda_py_test( name = "densenet_graph_test", - size = "large", + size = "medium", srcs = ["densenet_graph_test.py"], additional_deps = [ ":densenet", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", "noasan", "nomsan", "notsan", "optonly", + "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD index 
d64c8eb9ce122fa277567b2fbc632abfbc72df64..d99a519112787bad664232983208279cfb4d0036 100644 --- a/tensorflow/contrib/eager/python/examples/gan/BUILD +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -9,6 +9,13 @@ py_binary( name = "mnist", srcs = ["mnist.py"], srcs_version = "PY2AND3", + deps = [":mnist_lib"], +) + +py_library( + name = "mnist_lib", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -20,7 +27,7 @@ cuda_py_test( name = "mnist_test", srcs = ["mnist_test.py"], additional_deps = [ - ":mnist", + ":mnist_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], @@ -30,7 +37,7 @@ cuda_py_test( name = "mnist_graph_test", srcs = ["mnist_graph_test.py"], additional_deps = [ - ":mnist", + ":mnist_lib", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb index 1a08cc0fd06516be4af5c2b0b46a3ffcf9101e95..e1a02db76f705414a34d232022f50124a5a6a3ed 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -13,11 +13,13 @@ "\n", "# Convolutional VAE: An example with tf.keras and eager\n", "\n", + "This example has moved:\n", + "\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cvae.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", 
"\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { @@ -28,604 +30,14 @@ }, "source": [ "![evolution of output during training](https://tensorflow.org/images/autoencoders/cvae.gif)\n", - "\n", - "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. 
(VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n", "\n" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "P-JuIu2N_SQf" - }, - "outputs": [], - "source": [ - "# to generate gifs\n", - "!pip install imageio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "source": [ - "## Import TensorFlow and enable Eager execution" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "YfIk2es3hJEd" - }, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function\n", - "\n", - "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", - "import tensorflow as tf\n", - "tfe = tf.contrib.eager\n", - "tf.enable_eager_execution()\n", - "\n", - "import os\n", - "import time\n", - "import numpy as np\n", - "import glob\n", - "import matplotlib.pyplot as plt\n", - "import PIL\n", - "import imageio\n", - "from IPython import display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "source": [ - "## Load the MNIST dataset\n", - "Each MNIST image is originally a vector of 784 integers, each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "a4fYMGxGhrna" - }, - "outputs": [], - "source": [ - "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "NFC2ghIdiZYE" - }, - "outputs": [], - "source": [ - "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n", - "\n", - "# Normalizing the images to the range of [0., 1.]\n", - "train_images /= 255.\n", - "test_images /= 255.\n", - "\n", - "# Binarization\n", - "train_images[train_images \u003e= .5] = 1.\n", - "train_images[train_images \u003c .5] = 0.\n", - "test_images[test_images \u003e= .5] = 1.\n", - "test_images[test_images \u003c .5] = 0." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "S4PIDhoDLbsZ" - }, - "outputs": [], - "source": [ - "TRAIN_BUF = 60000\n", - "BATCH_SIZE = 100\n", - "\n", - "TEST_BUF = 10000" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "source": [ - "## Use *tf.data* to create batches and shuffle the dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "-yKCCQOoJ7cn" - }, - "outputs": [], - "source": [ - "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n", - "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "source": [ - "## Wire up the generative and inference network with *tf.keras.Sequential*\n", - "\n", - "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n", - "\n", - "### Generative Network\n", - "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n", - "\n", - "### Inference Network\n", - "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. 
In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n", - "\n", - "### Reparameterization Trick\n", - "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n", - "\n", - "### Network architecture\n", - "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "VGLbvBEmjK0a" - }, - "outputs": [], - "source": [ - "class CVAE(tf.keras.Model):\n", - " def __init__(self, latent_dim):\n", - " super(CVAE, self).__init__()\n", - " self.latent_dim = latent_dim\n", - " self.inference_net = tf.keras.Sequential(\n", - " [\n", - " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n", - " tf.keras.layers.Conv2D(\n", - " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", - " tf.keras.layers.Conv2D(\n", - " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", - " tf.keras.layers.Flatten(),\n", - " # No activation\n", - " tf.keras.layers.Dense(latent_dim + latent_dim),\n", - " ]\n", - " )\n", - "\n", - " self.generative_net = tf.keras.Sequential(\n", - " [\n", - " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n", - " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n", - " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=64,\n", - " kernel_size=3,\n", - " strides=(2, 2),\n", - " padding=\"SAME\",\n", - " activation=tf.nn.relu),\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=32,\n", - " kernel_size=3,\n", - " strides=(2, 2),\n", - " padding=\"SAME\",\n", - " activation=tf.nn.relu),\n", - " # No activation\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n", - " ]\n", - " )\n", - "\n", - " def sample(self, eps=None):\n", - " if eps is None:\n", - " eps = tf.random_normal(shape=(100, self.latent_dim))\n", - " return self.decode(eps, apply_sigmoid=True)\n", - "\n", - " def encode(self, x):\n", - " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n", - " return mean, logvar\n", - "\n", - " def reparameterize(self, mean, logvar):\n", - " eps = 
tf.random_normal(shape=mean.shape)\n", - " return eps * tf.exp(logvar * .5) + mean\n", - "\n", - " def decode(self, z, apply_sigmoid=False):\n", - " logits = self.generative_net(z)\n", - " if apply_sigmoid:\n", - " probs = tf.sigmoid(logits)\n", - " return probs\n", - "\n", - " return logits" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "source": [ - "## Define the loss function and the optimizer\n", - "\n", - "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n", - "\n", - "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n", - "\n", - "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n", - "\n", - "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n", - "where $z$ is sampled from $q(z|x)$.\n", - "\n", - "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "iWCn_PVdEJZ7" - }, - "outputs": [], - "source": [ - "def log_normal_pdf(sample, mean, logvar, raxis=1):\n", - " log2pi = tf.log(2. * np.pi)\n", - " return tf.reduce_sum(\n", - " -.5 * ((sample - mean) ** 2. 
* tf.exp(-logvar) + logvar + log2pi),\n", - " axis=raxis)\n", - "\n", - "def compute_loss(model, x):\n", - " mean, logvar = model.encode(x)\n", - " z = model.reparameterize(mean, logvar)\n", - " x_logit = model.decode(z)\n", - "\n", - " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n", - " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n", - " logpz = log_normal_pdf(z, 0., 0.)\n", - " logqz_x = log_normal_pdf(z, mean, logvar)\n", - " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n", - "\n", - "def compute_gradients(model, x):\n", - " with tf.GradientTape() as tape:\n", - " loss = compute_loss(model, x)\n", - " return tape.gradient(loss, model.trainable_variables), loss\n", - "\n", - "optimizer = tf.train.AdamOptimizer(1e-4)\n", - "def apply_gradients(optimizer, gradients, variables, global_step=None):\n", - " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "source": [ - "## Training\n", - "\n", - "* We start by iterating over the dataset\n", - "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n", - "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n", - "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n", - "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n", - "\n", - "## Generate Images\n", - "\n", - "* After training, it is time to generate some images\n", - "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n", - "* The 
generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n", - "* Here we plot the probabilities of Bernoulli distributions\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "NS2GWywBbAWo" - }, - "outputs": [], - "source": [ - "epochs = 100\n", - "latent_dim = 50\n", - "num_examples_to_generate = 16\n", - "\n", - "# keeping the random vector constant for generation (prediction) so\n", - "# it will be easier to see the improvement.\n", - "random_vector_for_generation = tf.random_normal(\n", - " shape=[num_examples_to_generate, latent_dim])\n", - "model = CVAE(latent_dim)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RmdVsmvhPxyy" - }, - "outputs": [], - "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " predictions = model.sample(test_input)\n", - " fig = plt.figure(figsize=(4,4))\n", - "\n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n", - " plt.axis('off')\n", - "\n", - " # tight_layout minimizes the overlap between 2 sub-plots\n", - " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "2M7LmLtGEMQJ" - }, - "outputs": [], - "source": [ - "generate_and_save_images(model, 0, random_vector_for_generation)\n", - "\n", - "for epoch in range(1, epochs + 1):\n", - " start_time = time.time()\n", - " for train_x in train_dataset:\n", - " gradients, loss = compute_gradients(model, train_x)\n", - " apply_gradients(optimizer, 
gradients, model.trainable_variables)\n", - " end_time = time.time()\n", - "\n", - " if epoch % 1 == 0:\n", - " loss = tfe.metrics.Mean()\n", - " for test_x in test_dataset:\n", - " loss(compute_loss(model, test_x))\n", - " elbo = -loss.result()\n", - " display.clear_output(wait=False)\n", - " print('Epoch: {}, Test set ELBO: {}, '\n", - " 'time elapse for current epoch {}'.format(epoch,\n", - " elbo,\n", - " end_time - start_time))\n", - " generate_and_save_images(\n", - " model, epoch, random_vector_for_generation)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "P4M_vIbUi7c0" - }, - "source": [ - "### Display an image using the epoch number" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "WfO5wCdclHGL" - }, - "outputs": [], - "source": [ - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "5x3q9_Oe5q0A" - }, - "outputs": [], - "source": [ - "display_image(epochs) # Display images" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" - }, - "source": [ - "### Generate a GIF of all the saved images." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "IGKQgENQ8lEI" - }, - "outputs": [], - "source": [ - "with imageio.get_writer('cvae.gif', mode='I') as writer:\n", - " filenames = glob.glob('image*.png')\n", - " filenames = sorted(filenames)\n", - " last = -1\n", - " for i,filename in enumerate(filenames):\n", - " frame = 2*(i**0.5)\n", - " if round(frame) \u003e round(last):\n", - " last = frame\n", - " else:\n", - " continue\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " \n", - "# this is a hack to display the gif inside the notebook\n", - "os.system('cp cvae.gif cvae.gif.png')" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "uV0yiKpzNP1b" - }, - "outputs": [], - "source": [ - "display.Image(filename=\"cvae.gif.png\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "yQXO_dlXkKsT" - }, - "source": [ - "To downlod the animation from Colab uncomment the code below:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4fSJS3m5HLFM" - }, - "outputs": [], - "source": [ - "#from google.colab import files\n", - "#files.download('cvae.gif')" - ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], - "default_view": {}, "name": "cvae.ipynb", "private_outputs": true, "provenance": [ @@ -635,8 +47,7 @@ } ], "toc_visible": true, - "version": "0.3.2", - "views": {} + "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", diff --git 
a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 78fcd397087fd1fd64aebed7ac3b5c6b2f45c450..53767058838459e56215d286e9f8f8eb66287147 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -1,26 +1,11 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "dcgan.ipynb", - "version": "0.3.2", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python2", - "display_name": "Python 2" - }, - "accelerator": "GPU" - }, "cells": [ { + "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0TD5ZrvEMbhZ" }, - "cell_type": "markdown", "source": [ "**Copyright 2018 The TensorFlow Authors**.\n", "\n", @@ -28,851 +13,39 @@ "\n", "# Generating Handwritten Digits with DCGAN\n", "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "ITZuApL56Mny" - }, - "cell_type": "markdown", - "source": [ - "This tutorial demonstrates how to generate images of handwritten digits using a Deep Convolutional Generative Adversarial Network ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)). The code is written in [tf.keras](https://www.tensorflow.org/programmers_guide/keras) with [eager execution](https://www.tensorflow.org/programmers_guide/eager) enabled. " - ] - }, - { - "metadata": { - "colab_type": "toc", - "id": "x2McrO9bMyLN" - }, - "cell_type": "markdown", - "source": [ - ">[Generating Handwritten Digits with DCGAN](#scrollTo=0TD5ZrvEMbhZ)\n", - "\n", - ">>[What are GANs?](#scrollTo=2MbKJY38Puy9)\n", - "\n", - ">>>[Import TensorFlow and enable eager execution](#scrollTo=e1_Y75QXJS6h)\n", - "\n", - ">>>[Load the dataset](#scrollTo=iYn4MdZnKCey)\n", - "\n", - ">>>[Use tf.data to create batches and shuffle the dataset](#scrollTo=PIGN6ouoQxt3)\n", - "\n", - ">>[Create the models](#scrollTo=THY-sZMiQ4UV)\n", - "\n", - ">>>[The Generator Model](#scrollTo=-tEyxE-GMC48)\n", - "\n", - ">>>[The Discriminator model](#scrollTo=D0IKnaCtg6WE)\n", - "\n", - ">>[Define the loss functions and the optimizer](#scrollTo=0FMYgY_mPfTi)\n", - "\n", - ">>>[Generator loss](#scrollTo=Jd-3GCUEiKtv)\n", - "\n", - ">>>[Discriminator loss](#scrollTo=PKY_iPSPNWoj)\n", - "\n", - ">>[Set up GANs for Training](#scrollTo=Rw1fkAczTQYh)\n", - "\n", - ">>[Train the GANs](#scrollTo=dZrd4CdjR-Fp)\n", - "\n", - ">>[Generated images](#scrollTo=P4M_vIbUi7c0)\n", + "This example has moved.\n", "\n", - ">>[Learn more about GANs](#scrollTo=k6qC-SbjK0yW)\n", - "\n" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/dcgan.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" 
/\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { + "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2MbKJY38Puy9" }, - "cell_type": "markdown", "source": [ - "## What are GANs?\n", - "GANs, or [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661), are a framework for estimating generative models. Two models are trained simultaneously by an adversarial process: a Generator, which is responsible for generating data (say, images), and a Discriminator, which is responsible for estimating the probability that an image was drawn from the training data (the image is real), or was produced by the Generator (the image is fake). During training, the Generator becomes progressively better at generating images, until the Discriminator is no longer able to distinguish real images from fake. \n", - "\n", - "![alt text](https://github.com/margaretmz/tensorflow/blob/margaret-dcgan/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png?raw=1)\n", - "\n", - "We will demonstrate this process end-to-end on MNIST. Below is an animation that shows a series of images produced by the Generator as it was trained for 50 epochs. Overtime, the generated images become increasingly difficult to distinguish from the training set.\n", - "\n", - "To learn more about GANs, we recommend MIT's [Intro to Deep Learning](http://introtodeeplearning.com/) course, which includes a lecture on Deep Generative Models ([video](https://youtu.be/JVb54xhEw6Y) | [slides](http://introtodeeplearning.com/materials/2018_6S191_Lecture4.pdf)). 
Now, let's head to the code!\n", - "\n", "![sample output](https://tensorflow.org/images/gan/dcgan.gif)" ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "dcgan.ipynb", + "provenance": [], + "version": "0.3.2" }, - { - "metadata": { - "colab_type": "code", - "id": "u_2z-B3piVsw", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Install imgeio in order to generate an animated gif showing the image generating process\n", - "!pip install imageio" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "cell_type": "markdown", - "source": [ - "### Import TensorFlow and enable eager execution" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "YfIk2es3hJEd", - "colab": {} - }, - "cell_type": "code", - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "import glob\n", - "import imageio\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import os\n", - "import PIL\n", - "import time\n", - "\n", - "from IPython import display" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "cell_type": "markdown", - "source": [ - "### Load the dataset\n", - "\n", - "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will generate handwritten digits resembling the MNIST data." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "a4fYMGxGhrna", - "colab": {} - }, - "cell_type": "code", - "source": [ - "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "NFC2ghIdiZYE", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "S4PIDhoDLbsZ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "BUFFER_SIZE = 60000\n", - "BATCH_SIZE = 256" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "cell_type": "markdown", - "source": [ - "### Use tf.data to create batches and shuffle the dataset" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "-yKCCQOoJ7cn", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "cell_type": "markdown", - "source": [ - "## Create the models\n", - "\n", - "We will use tf.keras [Sequential API](https://www.tensorflow.org/guide/keras#sequential_model) to define the generator and discriminator models." - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "-tEyxE-GMC48" - }, - "cell_type": "markdown", - "source": [ - "### The Generator Model\n", - "\n", - "The generator is responsible for creating convincing images that are good enough to fool the discriminator. 
The network architecture for the generator consists of [Conv2DTranspose](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose) (Upsampling) layers. We start with a fully connected layer and upsample the image two times in order to reach the desired image size of 28x28x1. We increase the width and height, and reduce the depth as we move through the layers in the network. We use [Leaky ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LeakyReLU) activation for each layer except for the last one where we use a tanh activation." - ] - }, - { - "metadata": { - "id": "6bpTcDqoLWjY", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def make_generator_model():\n", - " model = tf.keras.Sequential()\n", - " model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " \n", - " model.add(tf.keras.layers.Reshape((7, 7, 256)))\n", - " assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n", - " \n", - " model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n", - " assert model.output_shape == (None, 7, 7, 128) \n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - "\n", - " model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n", - " assert model.output_shape == (None, 14, 14, 64) \n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - "\n", - " model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n", - " assert model.output_shape == (None, 28, 28, 1)\n", - " \n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": 
"D0IKnaCtg6WE" - }, - "cell_type": "markdown", - "source": [ - "### The Discriminator model\n", - "\n", - "The discriminator is responsible for distinguishing fake images from real images. It's similar to a regular CNN-based image classifier." - ] - }, - { - "metadata": { - "id": "dw2tPLmk2pEP", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def make_discriminator_model():\n", - " model = tf.keras.Sequential()\n", - " model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " model.add(tf.keras.layers.Dropout(0.3))\n", - " \n", - " model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " model.add(tf.keras.layers.Dropout(0.3))\n", - " \n", - " model.add(tf.keras.layers.Flatten())\n", - " model.add(tf.keras.layers.Dense(1))\n", - " \n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "gDkA05NE6QMs", - "colab": {} - }, - "cell_type": "code", - "source": [ - "generator = make_generator_model()\n", - "discriminator = make_discriminator_model()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "cell_type": "markdown", - "source": [ - "## Define the loss functions and the optimizer\n", - "\n", - "Let's define the loss functions and the optimizers for the generator and the discriminator.\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "Jd-3GCUEiKtv" - }, - "cell_type": "markdown", - "source": [ - "### Generator loss\n", - "The generator loss is a sigmoid cross entropy loss of the generated images and an array of ones, since the generator is trying to generate fake images that resemble the real images." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "90BIcCKcDMxz", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def generator_loss(generated_output):\n", - " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "PKY_iPSPNWoj" - }, - "cell_type": "markdown", - "source": [ - "### Discriminator loss\n", - "\n", - "The discriminator loss function takes two inputs: real images, and generated images. Here is how to calculate the discriminator loss:\n", - "1. Calculate real_loss which is a sigmoid cross entropy loss of the real images and an array of ones (since these are the real images).\n", - "2. Calculate generated_loss which is a sigmoid cross entropy loss of the generated images and an array of zeros (since these are the fake images).\n", - "3. Calculate the total_loss as the sum of real_loss and generated_loss." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "wkMNfBWlT-PV", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def discriminator_loss(real_output, generated_output):\n", - " # [1,1,...,1] with real output since it is true and we want our generated examples to look like it\n", - " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", - "\n", - " # [0,0,...,0] with generated images since they are fake\n", - " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)\n", - "\n", - " total_loss = real_loss + generated_loss\n", - "\n", - " return total_loss" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "MgIc7i0th_Iu" - }, - "cell_type": "markdown", - "source": [ - "The discriminator and the generator optimizers are different since we will train two networks separately." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "iWCn_PVdEJZ7", - "colab": {} - }, - "cell_type": "code", - "source": [ - "generator_optimizer = tf.train.AdamOptimizer(1e-4)\n", - "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "mWtinsGDPJlV" - }, - "cell_type": "markdown", - "source": [ - "**Checkpoints (Object-based saving)**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "CA1w-7s2POEy", - "colab": {} - }, - "cell_type": "code", - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", - " discriminator_optimizer=discriminator_optimizer,\n", - " generator=generator,\n", - " discriminator=discriminator)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "cell_type": "markdown", - "source": [ - "## Set up GANs for Training\n", - "\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "5QC5BABamh_c" - }, - "cell_type": "markdown", - "source": [ - "Now it's time to put together the generator and discriminator to set up the Generative Adversarial Networks, as you see in the diagam at the beginning of the tutorial." 
- ] - }, - { - "metadata": { - "colab_type": "text", - "id": "Ff6oN6PZX27n" - }, - "cell_type": "markdown", - "source": [ - "**Define training parameters**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "NS2GWywBbAWo", - "colab": {} - }, - "cell_type": "code", - "source": [ - "EPOCHS = 50\n", - "noise_dim = 100\n", - "num_examples_to_generate = 16\n", - "\n", - "# We'll re-use this random vector used to seed the generator so\n", - "# it will be easier to see the improvement over time.\n", - "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", - " noise_dim])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "jylSonrqSWfi" - }, - "cell_type": "markdown", - "source": [ - "**Define training method**\n", - "\n", - "We start by iterating over the dataset. The generator is given a random vector as an input which is processed to output an image looking like a handwritten digit. The discriminator is then shown the real MNIST images as well as the generated images.\n", - "\n", - "Next, we calculate the generator and the discriminator loss. Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables." 
- ] - }, - { - "metadata": { - "id": "3t5ibNo05jCB", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def train_step(images):\n", - " # generating noise from a normal distribution\n", - " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", - " \n", - " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", - " generated_images = generator(noise, training=True)\n", - " \n", - " real_output = discriminator(images, training=True)\n", - " generated_output = discriminator(generated_images, training=True)\n", - " \n", - " gen_loss = generator_loss(generated_output)\n", - " disc_loss = discriminator_loss(real_output, generated_output)\n", - " \n", - " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", - " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", - " \n", - " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", - " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "6TSZgwc2BUQ-" - }, - "cell_type": "markdown", - "source": [ - "\n", - "This model takes about ~30 seconds per epoch to train on a single Tesla K80 on Colab, as of October 2018. \n", - "\n", - "Eager execution can be slower than executing the equivalent graph as it can't benefit from whole-program optimizations on the graph, and also incurs overheads of interpreting Python code. By using [tf.contrib.eager.defun](https://www.tensorflow.org/api_docs/python/tf/contrib/eager/defun) to create graph functions, we get a ~20 secs/epoch performance boost (from ~50 secs/epoch down to ~30 secs/epoch). This way we get the best of both eager execution (easier for debugging) and graph mode (better performance)." 
- ] - }, - { - "metadata": { - "id": "Iwya07_j5p2A", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_step = tf.contrib.eager.defun(train_step)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "2M7LmLtGEMQJ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def train(dataset, epochs): \n", - " for epoch in range(epochs):\n", - " start = time.time()\n", - " \n", - " for images in dataset:\n", - " train_step(images)\n", - "\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epoch + 1,\n", - " random_vector_for_generation)\n", - " \n", - " # saving (checkpoint) the model every 15 epochs\n", - " if (epoch + 1) % 15 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - " \n", - " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", - " time.time()-start))\n", - " # generating after the final epoch\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epochs,\n", - " random_vector_for_generation)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "2aFF7Hk3XdeW" - }, - "cell_type": "markdown", - "source": [ - "**Generate and save images**\n", - "\n" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "RmdVsmvhPxyy", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " # make sure the training parameter is set to False because we\n", - " # don't want to train the batchnorm layer when doing inference.\n", - " predictions = model(test_input, training=False)\n", - "\n", - " fig = plt.figure(figsize=(4,4))\n", - " \n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", - " plt.axis('off')\n", - " \n", - " 
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "dZrd4CdjR-Fp" - }, - "cell_type": "markdown", - "source": [ - "## Train the GANs\n", - "We will call the train() method defined above to train the generator and discriminator simultaneously. Note, training GANs can be tricky. It's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).\n", - "\n", - "At the beginning of the training, the generated images look like random noise. As training progresses, you can see the generated digits look increasingly real. After 50 epochs, they look very much like the MNIST digits." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "Ly3UN0SLLY2l", - "colab": {} - }, - "cell_type": "code", - "source": [ - "%%time\n", - "train(train_dataset, EPOCHS)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "rfM4YcPVPkNO" - }, - "cell_type": "markdown", - "source": [ - "**Restore the latest checkpoint**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "XhXsd0srPo8c", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "P4M_vIbUi7c0" - }, - "cell_type": "markdown", - "source": [ - "## Generated images \n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "mLskt7EfXAjr" - }, - "cell_type": "markdown", - "source": [ - "\n", - "After training, its time to generate some images! 
\n", - "The last step is to plot the generated images and voila!\n" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "WfO5wCdclHGL", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Display a single image using the epoch number\n", - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "5x3q9_Oe5q0A", - "colab": {} - }, - "cell_type": "code", - "source": [ - "display_image(EPOCHS)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" - }, - "cell_type": "markdown", - "source": [ - "**Generate a GIF of all the saved images**\n", - "\n", - "We will use imageio to create an animated gif using all the images saved during training." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "IGKQgENQ8lEI", - "colab": {} - }, - "cell_type": "code", - "source": [ - "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", - " filenames = glob.glob('image*.png')\n", - " filenames = sorted(filenames)\n", - " last = -1\n", - " for i,filename in enumerate(filenames):\n", - " frame = 2*(i**0.5)\n", - " if round(frame) > round(last):\n", - " last = frame\n", - " else:\n", - " continue\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " \n", - "# this is a hack to display the gif inside the notebook\n", - "os.system('cp dcgan.gif dcgan.gif.png')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "cGhC3-fMWSwl" - }, - "cell_type": "markdown", - "source": [ - "Display the animated gif with all the mages generated during the training of GANs." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "uV0yiKpzNP1b", - "colab": {} - }, - "cell_type": "code", - "source": [ - "display.Image(filename=\"dcgan.gif.png\")" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "6EEG-wePkmJQ" - }, - "cell_type": "markdown", - "source": [ - "**Download the animated gif**\n", - "\n", - "Uncomment the code below to download an animated gif from Colab." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "4UJjSnIMOzOJ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "#from google.colab import files\n", - "#files.download('dcgan.gif')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "k6qC-SbjK0yW" - }, - "cell_type": "markdown", - "source": [ - "## Learn more about GANs\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "xjjkT9KAK6H7" - }, - "cell_type": "markdown", - "source": [ - "We hope this tutorial was helpful! As a next step, you might like to experiment with a different dataset, for example the Large-scale Celeb Faces Attributes (CelebA) dataset [available on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset/home).\n", - "\n", - "To learn more about GANs:\n", - "\n", - "* Check out MIT's lecture (linked above), or [this](http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture12.pdf) lecture form Stanford's CS231n. 
\n", - "\n", - "* We also recommend the [CVPR 2018 Tutorial on GANs](https://sites.google.com/view/cvpr2018tutorialongans/), and the [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/abs/1701.00160).\n" - ] + "kernelspec": { + "display_name": "Python 2", + "name": "python2" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png b/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png deleted file mode 100644 index b715bd83ef117641c6429e0ac173dbe9b8d5fd88..0000000000000000000000000000000000000000 Binary files a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png and /dev/null differ diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 12c5eff2b4aa901bdab52bf545e95b1e4dce7468..979772acd3f823a8cc53ab5e026946ad3bb19353 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1,1174 +1,71 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "K2s1A9eLRPEj" - }, - "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\").\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Cffg2i257iMS" - }, - "source": [ - "# Image Captioning with Attention\n", - "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "QASbY_HGo4Lq" - }, - "source": [ - "Image captioning is the task of generating a caption for an image. Given an image like this:\n", - "\n", - "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", - "\n", - "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", - "\n", - "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", - "\n", - "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", - "\n", - "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", - "\n", - "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", - "\n", - "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", - "\n", - "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "U8l4RJ0XRPEm" - }, - "outputs": [], - "source": [ - "# Import TensorFlow and enable eager execution\n", - "# This code requires TensorFlow version >=1.9\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "# We'll generate plots of attention in order to see which parts of an image\n", - "# our model focuses on during captioning\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Scikit-learn includes many helpful utilities\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.utils import shuffle\n", - "\n", - "import re\n", - "import numpy as np\n", - "import os\n", - "import time\n", - "import json\n", - "from glob import glob\n", - "from PIL import Image\n", - "import pickle" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "b6qbGw8MRPE5" - }, - "source": [ - "## Download and prepare the MS-COCO dataset\n", - "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", - "\n", - "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "krQuPYTtRPE7" - }, - "outputs": [], - "source": [ - "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", - " extract = True)\n", - "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", - "\n", - "name_of_zip = 'train2014.zip'\n", - "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", - " image_zip = tf.keras.utils.get_file(name_of_zip, \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", - " extract = True)\n", - " PATH = os.path.dirname(image_zip)+'/train2014/'\n", - "else:\n", - " PATH = os.path.abspath('.')+'/train2014/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "aANEzb5WwSzg" - }, - "source": [ - "## Optionally, limit the size of the training set for faster training\n", - "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4G3b8x8_RPFD" - }, - "outputs": [], - "source": [ - "# read the json file\n", - "with open(annotation_file, 'r') as f:\n", - " annotations = json.load(f)\n", - "\n", - "# storing the captions and the image name in vectors\n", - "all_captions = []\n", - "all_img_name_vector = []\n", - "\n", - "for annot in annotations['annotations']:\n", - " caption = ' ' + annot['caption'] + ' '\n", - " image_id = annot['image_id']\n", - " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", - " \n", - " all_img_name_vector.append(full_coco_image_path)\n", - " all_captions.append(caption)\n", - "\n", - "# shuffling the captions and image_names together\n", - "# setting a random state\n", - "train_captions, img_name_vector = shuffle(all_captions,\n", - " all_img_name_vector,\n", - " random_state=1)\n", - "\n", - "# selecting the first 30000 captions from the shuffled set\n", - "num_examples = 30000\n", - "train_captions = train_captions[:num_examples]\n", - "img_name_vector = img_name_vector[:num_examples]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "mPBMgK34RPFL" - }, - "outputs": [], - "source": [ - "len(train_captions), len(all_captions)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "8cSW4u-ORPFQ" - }, - "source": [ - "## Preprocess the images using InceptionV3\n", - "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. 
\n", - "\n", - "First, we will need to convert the images into the format inceptionV3 expects by:\n", - "* Resizing the image to (299, 299)\n", - "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "zXR0217aRPFR" - }, - "outputs": [], - "source": [ - "def load_image(image_path):\n", - " img = tf.read_file(image_path)\n", - " img = tf.image.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize_images(img, (299, 299))\n", - " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", - " return img, image_path" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MDvIu4sXRPFV" - }, - "source": [ - "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", - "\n", - "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", - "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", - "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", - "* We avoid doing this during training so it does not become a bottleneck. \n", - "* After all the images are passed through the network, we pickle the dictionary and save it to disk." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RD3vW4SsRPFW" - }, - "outputs": [], - "source": [ - "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", - " weights='imagenet')\n", - "new_input = image_model.input\n", - "hidden_layer = image_model.layers[-1].output\n", - "\n", - "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "rERqlR3WRPGO" - }, - "source": [ - "## Caching the features extracted from InceptionV3\n", - "\n", - "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", - "\n", - "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", - "\n", - "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", - "\n", - "```for img, path in image_dataset:``` \n", - "\n", - "to:\n", - "\n", - "```for img, path in tqdm(image_dataset):```." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Dx_fvbVgRPGQ" - }, - "outputs": [], - "source": [ - "# getting the unique images\n", - "encode_train = sorted(set(img_name_vector))\n", - "\n", - "# feel free to change the batch_size according to your system configuration\n", - "image_dataset = tf.data.Dataset.from_tensor_slices(\n", - " encode_train).map(load_image).batch(16)\n", - "\n", - "for img, path in image_dataset:\n", - " batch_features = image_features_extract_model(img)\n", - " batch_features = tf.reshape(batch_features, \n", - " (batch_features.shape[0], -1, batch_features.shape[3]))\n", - "\n", - " for bf, p in zip(batch_features, path):\n", - " path_of_feature = p.numpy().decode(\"utf-8\")\n", - " np.save(path_of_feature, bf.numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "nyqH3zFwRPFi" - }, - "source": [ - "## Preprocess and tokenize the captions\n", - "\n", - "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", - "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", - "* Finally, we create a word --> index mapping and vice-versa.\n", - "* We will then pad all sequences to the be same length as the longest one. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "HZfK8RhQRPFj" - }, - "outputs": [], - "source": [ - "# This will find the maximum length of any caption in our dataset\n", - "def calc_max_length(tensor):\n", - " return max(len(t) for t in tensor)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "oJGE34aiRPFo" - }, - "outputs": [], - "source": [ - "# The steps above is a general process of dealing with text processing\n", - "\n", - "# choosing the top 5000 words from the vocabulary\n", - "top_k = 5000\n", - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", - " oov_token=\"\", \n", - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", - "tokenizer.fit_on_texts(train_captions)\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "8Q44tNQVRPFt" - }, - "outputs": [], - "source": [ - "tokenizer.word_index[''] = 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "0fpJb5ojRPFv" - }, - "outputs": [], - "source": [ - "# creating the tokenized vectors\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AidglIZVRPF4" - }, - "outputs": [], - "source": [ - "# padding each vector to the max_length of the captions\n", - "# if the max_length 
parameter is not provided, pad_sequences calculates that automatically\n", - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "gL0wkttkRPGA" - }, - "outputs": [], - "source": [ - "# calculating the max_length \n", - "# used to store the attention weights\n", - "max_length = calc_max_length(train_seqs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "M3CD75nDpvTI" - }, - "source": [ - "## Split the data into training and testing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "iS7DDMszRPGF" - }, - "outputs": [], - "source": [ - "# Create training and validation sets using 80-20 split\n", - "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", - " cap_vector, \n", - " test_size=0.2, \n", - " random_state=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "XmViPkRFRPGH" - }, - "outputs": [], - "source": [ - "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "uEWM9xrYcg45" - }, - "source": [ - "## Our images and captions are ready! 
Next, let's create a tf.data dataset to use for training our model.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Q3TnZ1ToRPGV" - }, - "outputs": [], - "source": [ - "# feel free to change these parameters according to your system's configuration\n", - "\n", - "BATCH_SIZE = 64\n", - "BUFFER_SIZE = 1000\n", - "embedding_dim = 256\n", - "units = 512\n", - "vocab_size = len(tokenizer.word_index)\n", - "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", - "# these two variables represent that\n", - "features_shape = 2048\n", - "attention_features_shape = 64" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "SmZS2N0bXG3T" - }, - "outputs": [], - "source": [ - "# loading the numpy files \n", - "def map_func(img_name, cap):\n", - " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", - " return img_tensor, cap" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "FDF_Nm3tRPGZ" - }, - "outputs": [], - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", - "\n", - "# using map to load the numpy files in parallel\n", - "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", - "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", - "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", - " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", - "\n", - "# shuffling and batching\n", - "dataset = dataset.shuffle(BUFFER_SIZE)\n", - "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", - "dataset 
= dataset.batch(BATCH_SIZE)\n", - "dataset = dataset.prefetch(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "nrvoDphgRPGd" - }, - "source": [ - "## Model\n", - "\n", - "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", - "\n", - "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", - "* We squash that to a shape of (64, 2048).\n", - "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", - "* The RNN(here GRU) attends over the image to predict the next word." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AAppCGLKRPGd" - }, - "outputs": [], - "source": [ - "def gru(units):\n", - " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", - " # significant speedup).\n", - " if tf.test.is_gpu_available():\n", - " return tf.keras.layers.CuDNNGRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " return tf.keras.layers.GRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "ja2LFTMSdeV3" - }, - "outputs": [], - 
"source": [ - "class BahdanauAttention(tf.keras.Model):\n", - " def __init__(self, units):\n", - " super(BahdanauAttention, self).__init__()\n", - " self.W1 = tf.keras.layers.Dense(units)\n", - " self.W2 = tf.keras.layers.Dense(units)\n", - " self.V = tf.keras.layers.Dense(1)\n", - " \n", - " def call(self, features, hidden):\n", - " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", - " \n", - " # hidden shape == (batch_size, hidden_size)\n", - " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", - " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", - " \n", - " # score shape == (batch_size, 64, hidden_size)\n", - " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", - " \n", - " # attention_weights shape == (batch_size, 64, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", - " \n", - " # context_vector shape after sum == (batch_size, hidden_size)\n", - " context_vector = attention_weights * features\n", - " context_vector = tf.reduce_sum(context_vector, axis=1)\n", - " \n", - " return context_vector, attention_weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AZ7R1RxHRPGf" - }, - "outputs": [], - "source": [ - "class CNN_Encoder(tf.keras.Model):\n", - " # Since we have already extracted the features and dumped it using pickle\n", - " # This encoder passes those features through a Fully connected layer\n", - " def __init__(self, embedding_dim):\n", - " super(CNN_Encoder, self).__init__()\n", - " # shape after fc == (batch_size, 64, embedding_dim)\n", - " self.fc = tf.keras.layers.Dense(embedding_dim)\n", - " \n", - " def call(self, x):\n", - " x = self.fc(x)\n", - " x = tf.nn.relu(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "V9UbGQmERPGi" - }, - "outputs": [], - "source": [ - "class RNN_Decoder(tf.keras.Model):\n", - " def __init__(self, embedding_dim, units, vocab_size):\n", - " super(RNN_Decoder, self).__init__()\n", - " self.units = units\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = gru(self.units)\n", - " self.fc1 = tf.keras.layers.Dense(self.units)\n", - " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " self.attention = BahdanauAttention(self.units)\n", - " \n", - " def call(self, x, features, hidden):\n", - " # defining attention as a separate model\n", - " context_vector, attention_weights = self.attention(features, hidden)\n", - " \n", - " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", - " x = self.embedding(x)\n", - " \n", - " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", - " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", - " \n", - " # passing the concatenated vector to the GRU\n", - " output, state = self.gru(x)\n", - " \n", - " # shape == (batch_size, max_length, hidden_size)\n", - " x = self.fc1(output)\n", - " \n", - " # x shape == (batch_size * max_length, hidden_size)\n", - " x = tf.reshape(x, (-1, x.shape[2]))\n", - " \n", - " # output shape == (batch_size * max_length, vocab)\n", - " x = self.fc2(x)\n", - "\n", - " return x, state, attention_weights\n", - "\n", - " def reset_state(self, batch_size):\n", - " return tf.zeros((batch_size, self.units))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Qs_Sr03wRPGk" - }, - "outputs": [], - "source": [ - "encoder = CNN_Encoder(embedding_dim)\n", - "decoder = 
RNN_Decoder(embedding_dim, units, vocab_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "-bYN7xA0RPGl" - }, - "outputs": [], - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# We are masking the loss calculated for padding\n", - "def loss_function(real, pred):\n", - " mask = 1 - np.equal(real, 0)\n", - " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", - " return tf.reduce_mean(loss_)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PHod7t72RPGn" - }, - "source": [ - "## Training\n", - "\n", - "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", - "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", - "* The decoder returns the predictions and the decoder hidden state.\n", - "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", - "* Use teacher forcing to decide the next input to the decoder.\n", - "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", - "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Vt4WZ5mhJE-E" - }, - "outputs": [], - "source": [ - "# adding this in a separate cell because if you run the training cell \n", - "# many times, the loss_plot array will be reset\n", - "loss_plot = []" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": 
false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "UlA4VIQpRPGo" - }, - "outputs": [], - "source": [ - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " total_loss = 0\n", - " \n", - " for (batch, (img_tensor, target)) in enumerate(dataset):\n", - " loss = 0\n", - " \n", - " # initializing the hidden state for each batch\n", - " # because the captions are not related from image to image\n", - " hidden = decoder.reset_state(batch_size=target.shape[0])\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", - " \n", - " with tf.GradientTape() as tape:\n", - " features = encoder(img_tensor)\n", - " \n", - " for i in range(1, target.shape[1]):\n", - " # passing the features through the decoder\n", - " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", - "\n", - " loss += loss_function(target[:, i], predictions)\n", - " \n", - " # using teacher forcing\n", - " dec_input = tf.expand_dims(target[:, i], 1)\n", - " \n", - " total_loss += (loss / int(target.shape[1]))\n", - " \n", - " variables = encoder.variables + decoder.variables\n", - " \n", - " gradients = tape.gradient(loss, variables) \n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - " \n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", - " batch, \n", - " loss.numpy() / int(target.shape[1])))\n", - " # storing the epoch end loss value to plot later\n", - " loss_plot.append(total_loss / len(cap_vector))\n", - " \n", - " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", - " total_loss/len(cap_vector)))\n", - " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "1Wm83G-ZBPcC" - }, - 
"outputs": [], - "source": [ - "plt.plot(loss_plot)\n", - "plt.xlabel('Epochs')\n", - "plt.ylabel('Loss')\n", - "plt.title('Loss Plot')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "xGvOcLQKghXN" - }, - "source": [ - "## Caption!\n", - "\n", - "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", - "* Stop predicting when the model predicts the end token.\n", - "* And store the attention weights for every time step." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RCWpDtyNRPGs" - }, - "outputs": [], - "source": [ - "def evaluate(image):\n", - " attention_plot = np.zeros((max_length, attention_features_shape))\n", - "\n", - " hidden = decoder.reset_state(batch_size=1)\n", - "\n", - " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", - " img_tensor_val = image_features_extract_model(temp_input)\n", - " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", - "\n", - " features = encoder(img_tensor_val)\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", - " result = []\n", - "\n", - " for i in range(max_length):\n", - " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", - "\n", - " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", - "\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", - " result.append(tokenizer.index_word[predicted_id])\n", - "\n", - " if tokenizer.index_word[predicted_id] == '':\n", - " return result, attention_plot\n", - "\n", - " dec_input = tf.expand_dims([predicted_id], 0)\n", - "\n", - " attention_plot = attention_plot[:len(result), 
:]\n", - " return result, attention_plot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "fD_y7PD6RPGt" - }, - "outputs": [], - "source": [ - "def plot_attention(image, result, attention_plot):\n", - " temp_image = np.array(Image.open(image))\n", - "\n", - " fig = plt.figure(figsize=(10, 10))\n", - " \n", - " len_result = len(result)\n", - " for l in range(len_result):\n", - " temp_att = np.resize(attention_plot[l], (8, 8))\n", - " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", - " ax.set_title(result[l])\n", - " img = ax.imshow(temp_image)\n", - " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", - "\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K2s1A9eLRPEj" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] }, - "colab_type": "code", - "id": "io7ws3ReRPGv" - }, - "outputs": [], - "source": [ - "# captions on the validation set\n", - "rid = np.random.randint(0, len(img_name_val))\n", - "image = img_name_val[rid]\n", - "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n", - "result, attention_plot = evaluate(image)\n", - "\n", - "print ('Real Caption:', real_caption)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image, result, attention_plot)\n", - "# opening the image\n", - "Image.open(img_name_val[rid])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rprk3HEvZuxb" - }, - "source": [ - "## Try it on your own images\n", - "For fun, 
below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Cffg2i257iMS" + }, + "source": [ + "# Image Captioning with Attention\n", + "\n", + "This example has moved:\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/image_captioning.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/image_captioning.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] }, - "colab_type": "code", - "id": "9Psd1quzaAWg" - }, - "outputs": [], - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_extension = image_url[-4:]\n", - "image_path = tf.keras.utils.get_file('image'+image_extension, \n", - " origin=image_url)\n", - "\n", - "result, attention_plot = evaluate(image_path)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image_path, result, attention_plot)\n", - "# opening the image\n", - "Image.open(image_path)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "VJZXyJco6uLO" - }, - "source": [ - "# Next steps\n", - 
"\n", - "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "image_captioning_with_attention.ipynb", - "private_outputs": true, - "provenance": [ { - "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", - "timestamp": 1530222436922 + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QASbY_HGo4Lq" + }, + "source": [ + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "image_captioning_with_attention.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 + } + ], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" } - ], - "toc_visible": true, - "version": "0.3.2", - "views": {} - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": 
"3.6.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index bda9e77085e45ae31a228142135425e22a1c6780..c945c753b3ba36d16aa6985d23a5849f8f552304 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -13,633 +13,13 @@ "\n", "# Text Generation using a RNN\n", "\n", + "This example has moved.\n", + "\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on Github\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "BwpJ5IffzRG6" - }, - "source": [ - "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). 
If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "This notebook is an end-to-end example. When you run it, it will download a dataset of Shakespeare's writing. We'll use a collection of plays, borrowed from Andrej Karpathy's excellent [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). The notebook will train a model, and use it to generate sample output.\n", - " \n", - "Here is the output(with start string='w') after training a single layer GRU for 30 epochs with the default settings below:\n", - "\n", - "```\n", - "were to the death of him\n", - "And nothing of the field in the view of hell,\n", - "When I said, banish him, I will not burn thee that would live.\n", - "\n", - "HENRY BOLINGBROKE:\n", - "My gracious uncle--\n", - "\n", - "DUKE OF YORK:\n", - "As much disgraced to the court, the gods them speak,\n", - "And now in peace himself excuse thee in the world.\n", - "\n", - "HORTENSIO:\n", - "Madam, 'tis not the cause of the counterfeit of the earth,\n", - "And leave me to the sun that set them on the earth\n", - "And leave the world and are revenged for thee.\n", - "\n", - "GLOUCESTER:\n", - "I would they were talking with the very name of means\n", - "To make a puppet of a guest, and therefore, good Grumio,\n", - "Nor arm'd to prison, o' the clouds, of the whole field,\n", - "With the admire\n", - "With the feeding of thy chair, and we have heard it so,\n", - "I thank you, sir, he is a visor friendship with your silly 
your bed.\n", - "\n", - "SAMPSON:\n", - "I do desire to live, I pray: some stand of the minds, make thee remedies\n", - "With the enemies of my soul.\n", - "\n", - "MENENIUS:\n", - "I'll keep the cause of my mistress.\n", - "\n", - "POLIXENES:\n", - "My brother Marcius!\n", - "\n", - "Second Servant:\n", - "Will't ple\n", - "```\n", - "\n", - "Of course, while some of the sentences are grammatical, most do not make sense. But, consider:\n", - "\n", - "* Our model is character based (when we began training, it did not yet know how to spell a valid English word, or that words were even a unit of text).\n", - "\n", - "* The structure of the output resembles a play (blocks begin with a speaker name, in all caps similar to the original text). Sentences generally end with a period. If you look at the text from a distance (or don't read the invididual words too closely, it appears as if it's an excerpt from a play).\n", - "\n", - "As a next step, you can experiment training the model on a different dataset - any large text file(ASCII) will do, and you can modify a single line of code below to make that change. Have fun!\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "R3p22DBDsaCA" - }, - "source": [ - "## Install unidecode library\n", - "A helpful library to convert unicode to ASCII." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "wZ6LOM12wKGH" - }, - "outputs": [], - "source": [ - "!pip install unidecode" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "WGyKZj3bzf9p" - }, - "source": [ - "## Import tensorflow and enable eager execution." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "yG_n40gFzf9s" - }, - "outputs": [], - "source": [ - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", - "import tensorflow as tf\n", - "\n", - "# Note: Once you enable eager execution, it cannot be disabled. \n", - "tf.enable_eager_execution()\n", - "\n", - "import numpy as np\n", - "import os\n", - "import re\n", - "import random\n", - "import unidecode\n", - "import time" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "EHDoRoc5PKWz" - }, - "source": [ - "## Download the dataset\n", - "\n", - "In this example, we will use the [shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). You can use any other dataset that you like.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "pD_55cOxLkAb" - }, - "outputs": [], - "source": [ - "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "UHjdCjDuSvX_" - }, - "source": [ - "## Read the dataset\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "-E5JvY3wzf94" - }, - "outputs": [], - "source": [ - "text = unidecode.unidecode(open(path_to_file).read())\n", - "# length of text is the number of characters in it\n", - "print (len(text))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Il9ww98izf-D" - }, - "source": [ - "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - 
"colab": {}, - "colab_type": "code", - "id": "IalZLbvOzf-F" - }, - "outputs": [], - "source": [ - "# unique contains all the unique characters in the file\n", - "unique = sorted(set(text))\n", - "\n", - "# creating a mapping from unique characters to indices\n", - "char2idx = {u:i for i, u in enumerate(unique)}\n", - "idx2char = {i:u for i, u in enumerate(unique)}" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "1v_qUYfAzf-I" - }, - "outputs": [], - "source": [ - "# setting the maximum length sentence we want for a single input in characters\n", - "max_length = 100\n", - "\n", - "# length of the vocabulary in chars\n", - "vocab_size = len(unique)\n", - "\n", - "# the embedding dimension \n", - "embedding_dim = 256\n", - "\n", - "# number of RNN (here GRU) units\n", - "units = 1024\n", - "\n", - "# batch size \n", - "BATCH_SIZE = 64\n", - "\n", - "# buffer size to shuffle our dataset\n", - "BUFFER_SIZE = 10000" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "LFjSVAlWzf-N" - }, - "source": [ - "## Creating the input and output tensors\n", - "\n", - "Vectorizing the input and the target text because our model cannot understand strings only numbers.\n", - "\n", - "But first, we need to create the input and output vectors.\n", - "Remember the max_length we set above, we will use it here. We are creating **max_length** chunks of input, where each input vector is all the characters in that chunk except the last and the target vector is all the characters in that chunk except the first.\n", - "\n", - "For example, consider that the string = 'tensorflow' and the max_length is 9\n", - "\n", - "So, the `input = 'tensorflo'` and `output = 'ensorflow'`\n", - "\n", - "After creating the vectors, we convert each character into numbers using the **char2idx** dictionary we created above." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "0UHJDA39zf-O" - }, - "outputs": [], - "source": [ - "input_text = []\n", - "target_text = []\n", - "\n", - "for f in range(0, len(text)-max_length, max_length):\n", - " inps = text[f:f+max_length]\n", - " targ = text[f+1:f+1+max_length]\n", - "\n", - " input_text.append([char2idx[i] for i in inps])\n", - " target_text.append([char2idx[t] for t in targ])\n", - " \n", - "print (np.array(input_text).shape)\n", - "print (np.array(target_text).shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MJdfPmdqzf-R" - }, - "source": [ - "## Creating batches and shuffling them using tf.data" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "p2pGotuNzf-S" - }, - "outputs": [], - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "m8gPwEjRzf-Z" - }, - "source": [ - "## Creating the model\n", - "\n", - "We use the Model Subclassing API which gives us full flexibility to create the model and change it however we like. 
We use 3 layers to define our model.\n", - "\n", - "* Embedding layer\n", - "* GRU layer (you can use an LSTM layer here)\n", - "* Fully connected layer" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "P3KTiiInzf-a" - }, - "outputs": [], - "source": [ - "class Model(tf.keras.Model):\n", - " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n", - " super(Model, self).__init__()\n", - " self.units = units\n", - " self.batch_sz = batch_size\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - "\n", - " if tf.test.is_gpu_available():\n", - " self.gru = tf.keras.layers.CuDNNGRU(self.units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " self.gru = tf.keras.layers.GRU(self.units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')\n", - "\n", - " self.fc = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " def call(self, x, hidden):\n", - " x = self.embedding(x)\n", - "\n", - " # output shape == (batch_size, max_length, hidden_size) \n", - " # states shape == (batch_size, hidden_size)\n", - "\n", - " # states variable to preserve the state of the model\n", - " # this will be used to pass at every step to the model while training\n", - " output, states = self.gru(x, initial_state=hidden)\n", - "\n", - "\n", - " # reshaping the output so that we can pass it to the Dense layer\n", - " # after reshaping the shape is (batch_size * max_length, hidden_size)\n", - " output = tf.reshape(output, (-1, output.shape[2]))\n", - "\n", - " # The dense layer will output predictions for every time_steps(max_length)\n", - " # output shape after the dense layer == (max_length * batch_size, vocab_size)\n", - " x = self.fc(output)\n", - "\n", - " return x, states" - ] - }, - { - 
"cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "trpqTWyvk0nr" - }, - "source": [ - "## Call the model and set the optimizer and the loss function" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "7t2XrzEOzf-e" - }, - "outputs": [], - "source": [ - "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "dkjWIATszf-h" - }, - "outputs": [], - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n", - "def loss_function(real, preds):\n", - " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3K6s6F79P7za" - }, - "source": [ - "## Checkpoints (Object-based saving)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "oAGisDdfP9rL" - }, - "outputs": [], - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", - " model=model)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "lPrP0XMUzf-p" - }, - "source": [ - "## Train the model\n", - "\n", - "Here we will use a custom training loop with the help of GradientTape()\n", - "\n", - "* We initialize the hidden state of the model with zeros and shape == (batch_size, number of rnn units). 
We do this by calling the function defined while creating the model.\n", - "\n", - "* Next, we iterate over the dataset(batch by batch) and calculate the **predictions and the hidden states** associated with that input.\n", - "\n", - "* There are a lot of interesting things happening here.\n", - " * The model gets hidden state(initialized with 0), lets call that **H0** and the first batch of input, lets call that **I0**.\n", - " * The model then returns the predictions **P1** and **H1**.\n", - " * For the next batch of input, the model receives **I1** and **H1**.\n", - " * The interesting thing here is that we pass **H1** to the model with **I1** which is how the model learns. The context learned from batch to batch is contained in the **hidden state**.\n", - " * We continue doing this until the dataset is exhausted and then we start a new epoch and repeat this.\n", - "\n", - "* After calculating the predictions, we calculate the **loss** using the loss function defined above. Then we calculate the gradients of the loss with respect to the model variables(input)\n", - "\n", - "* Finally, we take a step in that direction with the help of the optimizer using the apply_gradients function.\n", - "\n", - "Note:- If you are running this notebook in Colab which has a **Tesla K80 GPU** it takes about 23 seconds per epoch.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "d4tSNwymzf-q" - }, - "outputs": [], - "source": [ - "# Training step\n", - "\n", - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " \n", - " # initializing the hidden state at the start of every epoch\n", - " hidden = model.reset_states()\n", - " \n", - " for (batch, (inp, target)) in enumerate(dataset):\n", - " with tf.GradientTape() as tape:\n", - " # feeding the hidden state back into the model\n", - " # This is the interesting step\n", - " predictions, hidden = model(inp, hidden)\n", 
- " \n", - " # reshaping the target because that's how the \n", - " # loss function expects it\n", - " target = tf.reshape(target, (-1,))\n", - " loss = loss_function(target, predictions)\n", - " \n", - " grads = tape.gradient(loss, model.variables)\n", - " optimizer.apply_gradients(zip(grads, model.variables))\n", - "\n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n", - " batch,\n", - " loss))\n", - " # saving (checkpoint) the model every 5 epochs\n", - " if (epoch + 1) % 5 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - "\n", - " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", - " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "01AR9vpNQMFF" - }, - "source": [ - "## Restore the latest checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "tyvpYomYQQkF" - }, - "outputs": [], - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "DjGz1tDkzf-u" - }, - "source": [ - "## Predicting using our trained model\n", - "\n", - "The below code block is used to generated the text\n", - "\n", - "* We start by choosing a start string and initializing the hidden state and setting the number of characters we want to generate.\n", - "\n", - "* We get predictions using the start_string and the hidden state\n", - "\n", - "* Then we use argmax to calculate the index of the predicted word. 
**We use this predicted word as our next input to the model**\n", - "\n", - "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n", - "\n", - "* If you see the predictions, the model knows when to capitalize, make paragraphs and the text follows a shakespeare style of writing which is pretty awesome!" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "WvuwZBX5Ogfd" - }, - "outputs": [], - "source": [ - "# Evaluation step(generating text using the model learned)\n", - "\n", - "# number of characters to generate\n", - "num_generate = 1000\n", - "\n", - "# You can change the start string to experiment\n", - "start_string = 'Q'\n", - "# converting our start string to numbers(vectorizing!) 
\n", - "input_eval = [char2idx[s] for s in start_string]\n", - "input_eval = tf.expand_dims(input_eval, 0)\n", - "\n", - "# empty string to store our results\n", - "text_generated = ''\n", - "\n", - "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n", - "hidden = [tf.zeros((1, units))]\n", - "for i in range(num_generate):\n", - " predictions, hidden = model(input_eval, hidden)\n", - "\n", - " # using argmax to predict the word returned by the model\n", - " predicted_id = tf.argmax(predictions[-1]).numpy()\n", - " \n", - " # We pass the predicted word as the next input to the model\n", - " # along with the previous hidden state\n", - " input_eval = tf.expand_dims([predicted_id], 0)\n", - " \n", - " text_generated += idx2char[predicted_id]\n", - "\n", - "print (start_string + text_generated)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "AM2Uma_-yVIq" - }, - "source": [ - "## Next steps\n", - "\n", - "* Change the start string to a different character, or the start of a sentence.\n", - "* Experiment with training on a different, or with different parameters. 
[Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n", - "* Add another RNN layer.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "gtEd86sX5cB2" - }, - "outputs": [], - "source": [ - "" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD index 7bdf9053de749af9d09b12ba7b848e21c1fdb8f0..35d509904211d98f124d2555fc48166e75cb0dd9 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -28,7 +28,7 @@ py_library( cuda_py_test( name = "l2hmc_test", - size = "large", + size = "medium", srcs = ["l2hmc_test.py"], additional_deps = [ ":l2hmc", @@ -36,4 +36,8 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//third_party/py/numpy", ], + shard_count = 4, + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index 74ce9e84f013d79b3a33ffa79993980b561e366d..30afef83bc5c6c164c8456ed472f4d6064068a25 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -9,6 +9,13 @@ py_binary( name = "linear_regression", srcs = ["linear_regression.py"], srcs_version = "PY2AND3", + deps = [":linear_regression_lib"], +) + +py_library( + name = "linear_regression_lib", + srcs = ["linear_regression.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", 
"//tensorflow/contrib/eager/python:tfe", @@ -20,10 +27,13 @@ cuda_py_test( size = "small", srcs = ["linear_regression_test.py"], additional_deps = [ - ":linear_regression", + ":linear_regression_lib", "//tensorflow:tensorflow_py", ], - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_windows", # TODO: needs investigation on Windows + "oss_serial", + ], ) cuda_py_test( @@ -31,7 +41,7 @@ cuda_py_test( size = "small", srcs = ["linear_regression_graph_test.py"], additional_deps = [ - ":linear_regression", + ":linear_regression_lib", "//tensorflow:tensorflow_py", ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 099b712fc06d1d3eb9ab4095f8db7283690bda76..206ef9409df7b1dc21de42ba919d2ba97f334a8c 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -56,7 +56,7 @@ class LinearModel(tf.keras.Model): def mean_square_loss(model, xs, ys): - return tf.reduce_mean(tf.square(tf.subtract(model(xs), ys))) + return tf.reduce_mean(tf.squared_difference(model(xs), ys)) def fit(model, dataset, optimizer, verbose=False, logdir=None): diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 66d52a74943d0d81fde05ce51b019558b327978d..436e887736158ec1ba8e46eac8de4ac7b8e6be01 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -1,11 +1,28 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "nmt_with_attention.ipynb", + "version": "0.3.2", + "provenance": [], + "private_outputs": true, + 
"collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "accelerator": "GPU" + }, "cells": [ { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AOpGoE2T-YXS" }, + "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors.\n", "\n", @@ -13,19 +30,19 @@ "\n", "# Neural Machine Translation with Attention\n", "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", - "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "
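<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n", + "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\">\n", + " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n", + "</td><td>\n", + "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>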
" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "CiwtNgENbx2g" }, + "cell_type": "markdown", "source": [ "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n", "\n", @@ -33,24 +50,22 @@ "\n", "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n", "\n", - "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n", + "\"spanish-english\n", "\n", "Note: This example takes approximately 10 mintues to run on a single P100 GPU." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "tnxXKDjq3jEL" + "id": "tnxXKDjq3jEL", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "# Import TensorFlow >= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "\n", "tf.enable_eager_execution()\n", @@ -65,14 +80,16 @@ "import time\n", "\n", "print(tf.__version__)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "wfodePkj3jEa" }, + "cell_type": "markdown", "source": [ "## Download and prepare the dataset\n", "\n", @@ -91,14 +108,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "kRVATYOgJs1b" + "id": "kRVATYOgJs1b", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Download the 
file\n", "path_to_zip = tf.keras.utils.get_file(\n", @@ -106,17 +121,17 @@ " extract=True)\n", "\n", "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "rd0jw-eC3jEh" + "id": "rd0jw-eC3jEh", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Converts the unicode file to ascii\n", "def unicode_to_ascii(s):\n", @@ -128,7 +143,7 @@ " w = unicode_to_ascii(w.lower().strip())\n", " \n", " # creating a space between a word and the punctuation following it\n", - " # eg: \"he is a boy.\" =\u003e \"he is a boy .\" \n", + " # eg: \"he is a boy.\" => \"he is a boy .\" \n", " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", " w = re.sub(r'[\" \"]+', \" \", w)\n", @@ -140,19 +155,19 @@ " \n", " # adding a start and an end token to the sentence\n", " # so that the model know when to start and stop predicting.\n", - " w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n", + " w = ' ' + w + ' '\n", " return w" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "OHn4Dct23jEm" + "id": "OHn4Dct23jEm", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# 1. Remove the accents\n", "# 2. 
Clean the sentences\n", @@ -163,20 +178,20 @@ " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", " \n", " return word_pairs" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "9xbqO7Iie9bb" + "id": "9xbqO7Iie9bb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "# This class creates a word -\u003e index mapping (e.g,. \"dad\" -\u003e 5) and vice-versa \n", - "# (e.g., 5 -\u003e \"dad\") for each language,\n", + "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n", + "# (e.g., 5 -> \"dad\") for each language,\n", "class LanguageIndex():\n", " def __init__(self, lang):\n", " self.lang = lang\n", @@ -192,23 +207,23 @@ " \n", " self.vocab = sorted(self.vocab)\n", " \n", - " self.word2idx['\u003cpad\u003e'] = 0\n", + " self.word2idx[''] = 0\n", " for index, word in enumerate(self.vocab):\n", " self.word2idx[word] = index + 1\n", " \n", " for word, index in self.word2idx.items():\n", " self.idx2word[index] = word" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "eAY9k49G3jE_" + "id": "eAY9k49G3jE_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def max_length(tensor):\n", " return max(len(t) for t in tensor)\n", @@ -244,71 +259,71 @@ " padding='post')\n", " \n", " return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GOi42V79Ydlr" }, + "cell_type": "markdown", "source": [ "### Limit the size of the dataset to experiment faster (optional)\n", "\n", - "Training on the complete dataset of \u003e100,000 sentences will take a long time. 
To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" + "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "cnxC7q-j3jFD" + "id": "cnxC7q-j3jFD", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Try experimenting with the size of that dataset\n", "num_examples = 30000\n", "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "4QILQkOs3jFG" + "id": "4QILQkOs3jFG", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Creating training and validation sets using an 80-20 split\n", "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", "\n", "# Show length\n", "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rgCLkfv5uO3d" }, + "cell_type": "markdown", "source": [ "### Create a tf.data dataset" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "TqHsArVZ3jFS" + "id": "TqHsArVZ3jFS", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", @@ -320,27 +335,29 @@ "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, 
target_tensor_train)).shuffle(BUFFER_SIZE)\n", "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "TNfHIF71ulLu" }, + "cell_type": "markdown", "source": [ "## Write the encoder and decoder model\n", "\n", - "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", + "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://github.com/tensorflow/nmt). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism) from the seq2seq tutorial. The following diagram shows that each input word is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", "\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n", + "\"attention\n", "\n", "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. 
\n", "\n", "Here are the equations that are implemented:\n", "\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n", + "\"attention\n", + "\"attention\n", "\n", "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n", "\n", @@ -362,14 +379,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "avyJ_4VIUoHb" + "id": "avyJ_4VIUoHb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def gru(units):\n", " # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n", @@ -385,17 +400,17 @@ " return_state=True, \n", " recurrent_activation='sigmoid', \n", " recurrent_initializer='glorot_uniform')" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "nZ2rI24i3jFg" + "id": "nZ2rI24i3jFg", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "class Encoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", @@ -412,17 +427,17 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.enc_units))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "yJ_B3mhW3jFk" + "id": "yJ_B3mhW3jFk", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "class Decoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", @@ -476,41 +491,41 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.dec_units))" - ] + ], + 
"execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "P5UY8wko3jFp" + "id": "P5UY8wko3jFp", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_ch_71VbIRfK" }, + "cell_type": "markdown", "source": [ "## Define the optimizer and the loss function" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "WmTHr5iV3jFr" + "id": "WmTHr5iV3jFr", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "optimizer = tf.train.AdamOptimizer()\n", "\n", @@ -519,41 +534,43 @@ " mask = 1 - np.equal(real, 0)\n", " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", " return tf.reduce_mean(loss_)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DMVWzzsfNl4e" }, + "cell_type": "markdown", "source": [ "## Checkpoints (Object-based saving)" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "Zj8bXQTgNwrF" + "id": "Zj8bXQTgNwrF", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", " encoder=encoder,\n", " decoder=decoder)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hpObfY22IddU" }, + "cell_type": "markdown", "source": [ "## Training\n", "\n", @@ -567,14 +584,12 @@ ] }, { - "cell_type": "code", - 
"execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "ddefjBMa3jF0" + "id": "ddefjBMa3jF0", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "EPOCHS = 10\n", "\n", @@ -592,7 +607,7 @@ " \n", " dec_hidden = enc_hidden\n", " \n", - " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']] * BATCH_SIZE, 1) \n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']] * BATCH_SIZE, 1) \n", " \n", " # Teacher forcing - feeding the target as the next input\n", " for t in range(1, targ.shape[1]):\n", @@ -625,14 +640,16 @@ " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", " total_loss / N_BATCH))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mU3Ce8M6I3rz" }, + "cell_type": "markdown", "source": [ "## Translate\n", "\n", @@ -644,14 +661,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "EbQpyYs13jF_" + "id": "EbQpyYs13jF_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", @@ -668,7 +683,7 @@ " enc_out, enc_hidden = encoder(inputs, hidden)\n", "\n", " dec_hidden = enc_hidden\n", - " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']], 0)\n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']], 0)\n", "\n", " for t in range(max_length_targ):\n", " predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n", @@ -681,24 +696,24 @@ "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", - " if targ_lang.idx2word[predicted_id] == '\u003cend\u003e':\n", + " if targ_lang.idx2word[predicted_id] == '':\n", " return result, sentence, attention_plot\n", " \n", 
" # the predicted ID is fed back into the model\n", " dec_input = tf.expand_dims([predicted_id], 0)\n", "\n", " return result, sentence, attention_plot" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "s5hQWlbN3jGF" + "id": "s5hQWlbN3jGF", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# function for plotting the attention weights\n", "def plot_attention(attention, sentence, predicted_sentence):\n", @@ -712,17 +727,17 @@ " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", "\n", " plt.show()" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "sl9zUHzg3jGI" + "id": "sl9zUHzg3jGI", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n", @@ -732,91 +747,93 @@ " \n", " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "n250XbnjOaqP" }, + "cell_type": "markdown", "source": [ "## Restore the latest checkpoint and test" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "UJpT9D5_OgP6" + "id": "UJpT9D5_OgP6", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# restoring the latest checkpoint in checkpoint_dir\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - 
"execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "WrAM0FDomq3E" + "id": "WrAM0FDomq3E", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "zSx2iM36EZQZ" + "id": "zSx2iM36EZQZ", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "A3LLCx3ZE0Ls" + "id": "A3LLCx3ZE0Ls", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "DUQVLVqUE1YW" + "id": "DUQVLVqUE1YW", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# wrong translation\n", "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RTe5P5ioMJwN" }, + "cell_type": "markdown", "source": [ "## Next steps\n", "\n", @@ -824,31 +841,5 @@ "* Experiment with training on a larger dataset, or using more epochs\n" ] } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "nmt_with_attention.ipynb", - "private_outputs": true, - "provenance": [ - { - "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", - "timestamp": 
1527858391290 - }, - { - "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", - "timestamp": 1527776041613 - } - ], - "toc_visible": true, - "version": "0.3.2" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index f3135a9668fc0dc7faa93a5f119b53f3efd34c6e..f2851d97223e483da11120f1fe3f0a2f641dfb81 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -27,7 +27,7 @@ py_library( cuda_py_test( name = "resnet50_test", - size = "large", + size = "medium", srcs = ["resnet50_test.py"], additional_deps = [ ":resnet50", @@ -35,17 +35,19 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "noasan", # Fix b/118130911 "nomsan", # Fix b/118130911 "notsan", # Fix b/118130911 "optonly", + "oss_serial", ], ) cuda_py_test( name = "resnet50_graph_test", - size = "large", + size = "medium", srcs = ["resnet50_graph_test.py"], additional_deps = [ ":resnet50", @@ -53,10 +55,12 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "noasan", "nomsan", "notsan", "optonly", + "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index 4f0d46b1bae3760a63b2abe871034bdedf258f07..cb207b8ddf3641a68a114386f6a95a26ce2b74d6 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -67,30 +67,36 @@ py_library( # Tests cuda_py_test( name = "ops_test", - size = "large", + size = "medium", srcs = ["ops_test.py"], additional_deps = [ ":ops", "//tensorflow:tensorflow_py", ], + shard_count = 4, + tags = [ + "oss_serial", + ], ) 
cuda_py_test( name = "blocks_test", - size = "large", + size = "medium", srcs = ["blocks_test.py"], additional_deps = [ ":blocks", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ + "no_oss", # b/123045964 "optonly", ], ) cuda_py_test( name = "revnet_test", - size = "large", + size = "medium", srcs = ["revnet_test.py"], additional_deps = [ ":blocks_test", @@ -98,9 +104,11 @@ cuda_py_test( ":revnet", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", # depends on blocks_test, which is not available in pip package "optonly", + "oss_serial", ], ) @@ -127,6 +135,13 @@ py_binary( name = "main", srcs = ["main.py"], srcs_version = "PY2AND3", + deps = [":main_lib"], +) + +py_library( + name = "main_lib", + srcs = ["main.py"], + srcs_version = "PY2AND3", deps = [ ":cifar_input", ":config", @@ -141,7 +156,7 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], @@ -153,7 +168,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], @@ -165,7 +180,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index 1f2cb14972f0b92d29489adff8f94e790e1ec4ed..7406787ba438345dc485c50e347e40597b2037f5 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -96,6 +96,7 @@ class RevNet(tf.keras.Model): def call(self, inputs, training=True): """Forward pass.""" + saved_hidden = None if training: saved_hidden = [inputs] diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index 
d500b632ebb97fd12ded3a215b0f1a686194874f..f4dbe7ac16f734f7bee045bc71e9559b630adf81 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -9,6 +9,13 @@ py_binary( name = "rnn_colorbot", srcs = ["rnn_colorbot.py"], srcs_version = "PY2AND3", + deps = [":rnn_colorbot_lib"], +) + +py_library( + name = "rnn_colorbot_lib", + srcs = ["rnn_colorbot.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -21,8 +28,11 @@ cuda_py_test( name = "rnn_colorbot_test", srcs = ["rnn_colorbot_test.py"], additional_deps = [ - ":rnn_colorbot", + ":rnn_colorbot_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 74ebb1ec77131a560b1ebfd062c690920c35e261..1c718a5ce3d8e1541656d92fd5e8dad6c6683c4c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -207,7 +207,7 @@ class RNNColorbot(tf.keras.Model): def loss(labels, predictions): """Computes mean squared loss.""" - return tf.reduce_mean(tf.square(predictions - labels)) + return tf.reduce_mean(tf.squared_difference(predictions, labels)) def test(model, eval_data): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index 2cc2fcbfeb21ee6218d7912d9a93ea2f7b2ea226..43a6ca526d3a0aecda2c8df865a0487ac28758ab 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -9,6 +9,13 @@ py_binary( name = "rnn_ptb", srcs = ["rnn_ptb.py"], srcs_version = "PY2AND3", + deps = [":rnn_ptb_lib"], +) + +py_library( + name = "rnn_ptb_lib", + srcs = 
["rnn_ptb.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", @@ -21,18 +28,22 @@ cuda_py_test( name = "rnn_ptb_test", srcs = ["rnn_ptb_test.py"], additional_deps = [ - ":rnn_ptb", + ":rnn_ptb_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + tags = ["no_oss"], # b/123045964 ) cuda_py_test( name = "rnn_ptb_graph_test", srcs = ["rnn_ptb_graph_test.py"], additional_deps = [ - ":rnn_ptb", + ":rnn_ptb_lib", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 15776c694e92825895437a4c1547699f6d9269fb..9b5a2c947b153308c83f1a922d06c034ec5f9ddf 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -128,7 +128,7 @@ class PTBModel(tf.keras.Model): self.linear = layers.Dense( vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) - self._output_shape = [-1, embedding_dim] + self._output_shape = [-1, hidden_dim] def call(self, input_seq, training): """Run the forward pass of PTBModel. diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 5966f1d4873e8e77b3ad5914da7bfc7e69d4e341..9b0fbaa6793e28d327745767e6ccd3085211ff7d 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -42,5 +42,6 @@ cuda_py_test( "no-internal-py3", # flaky "no_cuda_on_cpu_tap", "no_pip", # because spinn.py is under third_party/. 
+ "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index d18a097063c7d25947af3e2e2959ce574edd553f..3143270ccfe4f670428c80bdc1e09fa452584207 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -37,7 +37,7 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils # pylint: enable=g-bad-import-order @@ -421,7 +421,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - object_graph = checkpointable_utils.object_metadata( + object_graph = trackable_utils.object_metadata( checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 566246de4957c1dc5919c10e22146706f9e50be8..b32501c2e804838af9d4c77663be131b77bd30b4 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -32,12 +32,12 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable _to_replace = re.compile("[^A-Za-z0-9.]") -class 
Metric(checkpointable.CheckpointableBase): +class Metric(trackable.Trackable): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: @@ -269,7 +269,7 @@ class Metric(checkpointable.CheckpointableBase): else: collections = [ops.GraphKeys.LOCAL_VARIABLES] collections += [ops.GraphKeys.METRIC_VARIABLES] - # Variables are Checkpointable dependencies of Metrics regardless of the + # Variables are Trackable dependencies of Metrics regardless of the # global/local distinction. Users can avoid saving variables by not adding a # dependency on the Metric. v = self._add_variable_with_custom_getter( @@ -282,7 +282,7 @@ class Metric(checkpointable.CheckpointableBase): use_resource=True, getter=variable_scope.get_variable, # Raise duplicate variable exceptions from get_variable rather than - # Checkpointable. + # Trackable. overwrite=True) self._vars.append(v) if context.executing_eagerly(): diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 39e5957f5d1760613f2c33607c0bdb163040efb4..c56d1956fde35b562e60496015e666efe9ebc8f6 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -35,7 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils class MetricsTest(test.TestCase): @@ -314,7 +314,7 @@ class MetricsTest(test.TestCase): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") mean = metrics.Mean() - checkpoint = checkpointable_utils.Checkpoint(mean=mean) + checkpoint = trackable_utils.Checkpoint(mean=mean) mean.build() mean._built 
= True self.evaluate(mean.init_variables()) @@ -327,7 +327,7 @@ class MetricsTest(test.TestCase): self.assertAllEqual(200., self.evaluate(mean.value())) restore_mean = metrics.Mean() - restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + restore_checkpoint = trackable_utils.Checkpoint(mean=restore_mean) status = restore_checkpoint.restore(save_path) restore_update = restore_mean(300.) status.assert_consumed().run_restore_ops() diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 240f213c602395b8589d39c3ecd90b602ffa9848..b3e8daddaf2369e9e33179fde2aab1469e97ea47 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils # pylint: disable=not-callable @@ -65,7 +65,7 @@ class NetworkTest(test.TestCase): def test_checkpointing_not_implemented(self): checkpoint_directory = self.get_temp_dir() - checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork()) + checkpoint = trackable_utils.Checkpoint(net=MyNetwork()) with self.assertRaises(NotImplementedError): checkpoint.save(checkpoint_directory) diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py index 7803a6799bb64441fab881bf6ca986d5cf3851a8..258f0a19309235dcd99b31b4de3d35ef8d89b15b 100644 --- a/tensorflow/contrib/eager/python/parameter_server.py +++ b/tensorflow/contrib/eager/python/parameter_server.py @@ -30,7 +30,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import 
resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): @@ -129,8 +129,8 @@ class SharedVariable(resource_variable_ops.ResourceVariable): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") - if isinstance(initial_value, checkpointable.CheckpointInitialValue): - self._maybe_initialize_checkpointable() + if isinstance(initial_value, trackable.CheckpointInitialValue): + self._maybe_initialize_trackable() self._update_uid = initial_value.checkpoint_position.restore_uid initial_value = initial_value.wrapped_value diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index 3926de15e71c9917f88fc3f58740b8c75354ab26..f540d9b37b69c7be3b0662b07bd6e9cb8220fadc 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -24,12 +24,12 @@ import os import numpy as np from tensorflow.contrib.eager.python import parameter_server -from tensorflow.contrib.eager.python import remote from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function +from tensorflow.python.eager import remote from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 33c988fd9065e7fbe7b9aeb85cad82eb3c119f76..df5b059448f735f7dc1f2963ffbc9c8a8287250a 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ 
-41,6 +41,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@add_execution_callback @@clear_execution_callbacks +@@errstate +@@ExecutionCallback @@inf_callback @@inf_nan_callback @@nan_callback @@ -60,7 +62,6 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@Checkpoint @@Checkpointable -@@CheckpointableSaver @@executing_eagerly @@in_eager_mode @@ -97,7 +98,6 @@ from tensorflow.contrib.eager.python.network import Network from tensorflow.contrib.eager.python.network import Sequential from tensorflow.contrib.eager.python.network import save_network_checkpoint from tensorflow.contrib.eager.python.network import restore_network_checkpoint -from tensorflow.contrib.eager.python.remote import connect_to_remote_host from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver @@ -119,10 +119,13 @@ from tensorflow.python.eager.context import set_server_def from tensorflow.python.eager.def_function import function from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks +from tensorflow.python.eager.execution_callbacks import errstate +from tensorflow.python.eager.execution_callbacks import ExecutionCallback from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.eager.remote import connect_to_remote_host from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution @@ -134,9 
+137,8 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.tracking import Checkpointable -from tensorflow.python.training.checkpointable.util import CheckpointableSaver -from tensorflow.python.training.checkpointable.util import Checkpoint +from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable +from tensorflow.python.training.tracking.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 8c35dddb5a515aa09cc70c173a9f0605e8567e82..6881fabdc09e3275c29f3013283999c96e283770 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import tempfile from tensorflow.contrib.eager.python import tfe +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -40,6 +41,9 @@ class TFETest(test_util.TensorFlowTestCase): self.assertAllEqual([[4.]], y.numpy()) def testInstantError(self): + if context.num_gpus(): + # TODO(nareshmodi): make this test better + self.skipTest("Gather doesn't do index checking on GPUs") with self.assertRaisesRegexp(errors.InvalidArgumentError, r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index e344d7a23b55134612aab430b50cf065bd1095e4..da2479a0b7b029561136903c82cabed9aae622b8 100644 --- a/tensorflow/contrib/factorization/BUILD +++ 
b/tensorflow/contrib/factorization/BUILD @@ -28,7 +28,6 @@ tf_custom_op_py_library( "python/ops/wals.py", ], dso = [ - ":python/ops/_clustering_ops.so", ":python/ops/_factorization_ops.so", ], kernels = [ @@ -38,12 +37,12 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":factorization_ops_test_utils_py", - ":gen_clustering_ops", ":gen_factorization_ops", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:clustering_ops_gen", "//tensorflow/python:control_flow_ops", "//tensorflow/python:data_flow_ops", "//tensorflow/python:embedding_ops", @@ -77,17 +76,6 @@ py_library( ], ) -# Ops -tf_custom_op_library( - name = "python/ops/_clustering_ops.so", - srcs = [ - "ops/clustering_ops.cc", - ], - deps = [ - "//tensorflow/contrib/factorization/kernels:clustering_ops", - ], -) - tf_custom_op_library( name = "python/ops/_factorization_ops.so", srcs = [ @@ -100,26 +88,16 @@ tf_custom_op_library( ) tf_gen_op_libs([ - "clustering_ops", "factorization_ops", ]) cc_library( name = "all_ops", deps = [ - ":clustering_ops_op_lib", ":factorization_ops_op_lib", ], ) -tf_gen_op_wrapper_py( - name = "gen_clustering_ops", - out = "python/ops/gen_clustering_ops.py", - deps = [ - ":clustering_ops_op_lib", - ], -) - tf_gen_op_wrapper_py( name = "gen_factorization_ops", out = "python/ops/gen_factorization_ops.py", @@ -131,7 +109,7 @@ tf_gen_op_wrapper_py( # Ops tests tf_py_test( name = "gmm_test", - size = "large", + size = "medium", srcs = [ "python/ops/gmm_test.py", ], @@ -152,6 +130,7 @@ tf_py_test( "//tensorflow/python:random_seed", "//tensorflow/python:training", ], + shard_count = 4, tags = [ "no_pip", # b/38283730 "notsan", # Flaky: b/30756419 @@ -224,10 +203,7 @@ py_test( srcs = ["python/ops/kmeans_test.py"], shard_count = 4, srcs_version = "PY2AND3", - tags = [ - "nomac", # b/73741358 - "notsan", # b/67512932 - ], + tags = ["notsan"], deps = [ 
":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", @@ -249,7 +225,7 @@ py_test( tf_py_test( name = "wals_test", - size = "large", + size = "medium", srcs = ["python/ops/wals_test.py"], additional_deps = [ ":factorization_py", @@ -272,8 +248,8 @@ tf_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], + shard_count = 4, tags = [ - "manual", "noasan", # times out b/63678675 "nomsan", ], diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD index ea8b9a17a27093cb57564861815edd6ecb18a014..23d7e088d067effa446e4bcdc9609db612066568 100644 --- a/tensorflow/contrib/factorization/kernels/BUILD +++ b/tensorflow/contrib/factorization/kernels/BUILD @@ -11,7 +11,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") cc_library( name = "all_kernels", deps = [ - ":clustering_ops", ":masked_matmul_ops", ":wals_solver_ops", "@protobuf_archive//:protobuf_headers", @@ -29,17 +28,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "clustering_ops", - srcs = ["clustering_ops.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - cc_library( name = "masked_matmul_ops", srcs = ["masked_matmul_ops.cc"], @@ -51,19 +39,3 @@ cc_library( ], alwayslink = 1, ) - -tf_cc_test( - name = "clustering_ops_test", - srcs = ["clustering_ops_test.cc"], - deps = [ - ":clustering_ops", - "//tensorflow/contrib/factorization:clustering_ops_op_lib", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) diff --git a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc index a8c5d0763c28ba2b54f217405f0da65533f26b91..68078ba8bbb07b4344c19d554012d214229f9c4f 
100644 --- a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc +++ b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc @@ -19,12 +19,12 @@ #include #include +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc deleted file mode 100644 index 2686702c1d5768f661dac610c96089eb02e360d7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/factorization/ops/clustering_ops.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2016 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not -// use this file except in compliance with the License. You may obtain a copy -// of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations under -// the License. 
-// ============================================================================== - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("KmeansPlusPlusInitialization") - .Input("points: float32") - .Input("num_to_sample: int64") - .Input("seed: int64") - .Input("num_retries_per_sample: int64") - .Output("samples: float32") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"( -Selects num_to_sample rows of input using the KMeans++ criterion. - -Rows of points are assumed to be input points. One row is selected at random. -Subsequent rows are sampled with probability proportional to the squared L2 -distance from the nearest row selected thus far till num_to_sample rows have -been sampled. - -points: Matrix of shape (n, d). Rows are assumed to be input points. -num_to_sample: Scalar. The number of rows to sample. This value must not be - larger than n. -seed: Scalar. Seed for initializing the random number generator. -num_retries_per_sample: Scalar. For each row that is sampled, this parameter - specifies the number of additional points to draw from the current - distribution before selecting the best. If a negative value is specified, a - heuristic is used to sample O(log(num_to_sample)) additional points. -samples: Matrix of shape (num_to_sample, d). The sampled rows. -)"); - -REGISTER_OP("KMC2ChainInitialization") - .Input("distances: float32") - .Input("seed: int64") - .Output("index: int64") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"( -Returns the index of a data point that should be added to the seed set. - -Entries in distances are assumed to be squared distances of candidate points to -the already sampled centers in the seed set. The op constructs one Markov chain -of the k-MC^2 algorithm and returns the index of one candidate point to be added -as an additional cluster center. 
- -distances: Vector with squared distances to the closest previously sampled - cluster center for each candidate point. -seed: Scalar. Seed for initializing the random number generator. -index: Scalar with the index of the sampled point. -)"); - -REGISTER_OP("NearestNeighbors") - .Input("points: float32") - .Input("centers: float32") - .Input("k: int64") - .Output("nearest_center_indices: int64") - .Output("nearest_center_distances: float32") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"( -Selects the k nearest centers for each point. - -Rows of points are assumed to be input points. Rows of centers are assumed to be -the list of candidate centers. For each point, the k centers that have least L2 -distance to it are computed. - -points: Matrix of shape (n, d). Rows are assumed to be input points. -centers: Matrix of shape (m, d). Rows are assumed to be centers. -k: Scalar. Number of nearest centers to return for each point. If k is larger - than m, then only m centers are returned. -nearest_center_indices: Matrix of shape (n, min(m, k)). Each row contains the - indices of the centers closest to the corresponding point, ordered by - increasing distance. -nearest_center_distances: Matrix of shape (n, min(m, k)). Each row contains the - squared L2 distance to the corresponding center in nearest_center_indices. 
-)"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 84e80791f4991ad2b67d0a00ee1e00cf0d0daadc..d48b89cbacce34781819010addbcbd0ba66f9873 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -18,28 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.factorization.python.ops import gen_clustering_ops -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import * -# pylint: enable=wildcard-import -from tensorflow.contrib.util import loader from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_clustering_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.embedding_ops import embedding_lookup -from tensorflow.python.platform import resource_loader - -_clustering_ops = loader.load_op_library( - resource_loader.get_path_to_datafile('_clustering_ops.so')) +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_clustering_ops import * +# pylint: enable=wildcard-import # Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\) # which is the square root of the sum of the absolute squares of the elements diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py 
b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index d365ad111760247fc18b730657390f07ba6b865e..9f0664dfe5ba7a098b6976388d1cf737dafb4842 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -314,8 +314,7 @@ class GmmAlgorithm(object): # reparametrization of variance parameters. det_expanded = math_ops.reduce_sum( math_ops.log(self._covs + 1e-3), 1, keepdims=True) - diff = shard - self._means - x2 = math_ops.square(diff) + x2 = math_ops.squared_difference(shard, self._means) cov_expanded = array_ops.expand_dims(1.0 / (self._covs + 1e-3), 2) # num_classes X num_examples x2_cov = math_ops.matmul(x2, cov_expanded) diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 1cd83bdb5de7c2f6dc91c980750b49aca1a7790b..0a9199d61f36f10c98b95d79ece7e86765d2db0e 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -6,7 +6,7 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "feature_column_py", @@ -14,7 +14,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":sequence_feature_column", - ":sequence_feature_column_v2", "//tensorflow/python:util", ], ) @@ -37,13 +36,13 @@ py_library( ], ) -py_test( +tf_py_test( name = "sequence_feature_column_test", srcs = ["python/feature_column/sequence_feature_column_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -53,17 +52,14 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column:feature_column_py", - "//third_party/py/numpy", - 
"@absl_py//absl/testing:parameterized", ], + tags = ["no_pip"], ) -py_test( +tf_py_test( name = "sequence_feature_column_integration_test", srcs = ["python/feature_column/sequence_feature_column_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -73,46 +69,5 @@ py_test( "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras:layers", ], -) - -py_library( - name = "sequence_feature_column_v2", - srcs = ["python/feature_column/sequence_feature_column_v2.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", - ], -) - -py_test( - name = "sequence_feature_column_v2_test", - srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], - srcs_version = "PY2AND3", tags = ["no_pip"], - deps = [ - ":sequence_feature_column", - ":sequence_feature_column_v2", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], ) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 
dad50a3a73085526f65bd87c3d8549ceb75b3af4..8fd2b5f39bc88b76fe5583f8d18389e232ea9f40 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -32,7 +32,6 @@ tf_custom_op_py_library( "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", - "python/ops/critical_section_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", "python/ops/script_ops.py", @@ -50,6 +49,8 @@ tf_custom_op_py_library( visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_estimator:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ @@ -170,26 +171,6 @@ py_test( ], ) -cuda_py_test( - name = "critical_section_test", - size = "medium", - srcs = ["python/ops/critical_section_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - ":framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:context", - ], -) - py_test( name = "ops_test", size = "small", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index e72e50585a3861d4527b66f89e1659d76c85960a..063717f08aa88f4de9470d8392db2b7c95b3e4bf 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -94,8 +94,6 @@ @@smart_constant_value @@smart_case -@@CriticalSection - @@BoundedTensorSpec @@TensorSpec @@ -129,18 +127,24 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', + 'is_nested', 
'is_sequence', + 'is_sequence_or_composite', 'flatten', 'flatten_dict_items', 'pack_sequence_as', 'map_structure', 'map_structure_with_paths', + 'map_structure_with_tuple_paths', 'assert_shallow_structure', 'flatten_up_to', + 'flatten_with_tuple_paths_up_to', 'map_structure_up_to', + 'map_structure_with_tuple_paths_up_to', 'get_traverse_shallow_structure', 'yield_flat_paths', 'flatten_with_joined_string_paths', + 'flatten_with_tuple_paths', ] remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols) diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index c4976497f5fa95d82e492153b117681f693eaa13..8113bf7c095bd0817e40cfd08bdf1ef7275ba55b 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -22,7 +22,6 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * -from tensorflow.contrib.framework.python.ops.critical_section_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * from tensorflow.contrib.framework.python.ops.script_ops import * diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 57a5bfbf43c915775c6b0ef05baac19581213a09..f65f450eba49163c319af54ec2bd7f6b61e34c1e 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -171,6 +171,7 @@ cuda_py_test( main = "python/ops/fused_conv2d_bias_activation_benchmark.py", tags = [ "manual", # TODO(b/117128481): re-enable after fixing OSS build + "nogpu", "requires-gpu-sm70", ], ) diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 
93b1aaa85e88e00c1b12a388321a4d6fb10f1611..f13a66717f67a1a627f66af9468c6f2897aaf7a4 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -19,13 +19,13 @@ limitations under the License. #include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" @@ -522,7 +522,7 @@ void LaunchFusedConv2DBiasActivationOp:: auto bias_ptr = AsDeviceMemory(bias.template flat().data(), bias.template flat().size()); - static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( + static int64 ConvolveScratchSize = GetDnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB ); @@ -565,12 +565,26 @@ void LaunchFusedConv2DBiasActivationOp:: fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo( stream->parent()), &algorithms)); + if (activation_mode == ActivationMode::NONE) { + // Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM is supported for + // identity activation, other algs seem to quietly do Relu. 
+ // See + // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward + algorithms.erase( + std::remove_if( + algorithms.begin(), algorithms.end(), + [](dnn::AlgorithmDesc alg) { + return alg.algo_id() != + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + }), + algorithms.end()); + } dnn::ProfileResult best_result; dnn::ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); dnn::ProfileResult profile_result; bool cudnn_launch_status = stream @@ -609,7 +623,7 @@ void LaunchFusedConv2DBiasActivationOp:: algorithm_config); } - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); bool cudnn_launch_status = stream ->ThenFusedConvolveWithAlgorithm( diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index f89d7ed0f45f919b17398de5d9449d12c08dd2f2..386e4cf69b7aa118a85fb25bcb809a879c5c1bd8 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -1,12 +1,14 @@ -# Files for using TFGAN framework. -package(default_visibility = ["//tensorflow:__subpackages__"]) +# Files for using TF-GAN framework. 
+load("//tensorflow:tensorflow.bzl", "py_test") + +package(default_visibility = [ + "//tensorflow:__subpackages__", +]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") - py_library( name = "gan", srcs = [ @@ -104,7 +106,9 @@ py_library( deps = [ ":gan_estimator", ":head", + ":latent_gan_estimator", ":stargan_estimator", + ":tpu_gan_estimator", "//tensorflow/python:util", ], ) @@ -128,6 +132,7 @@ py_library( ":clip_weights", ":conditioning_utils", ":random_tensor_pool", + ":spectral_normalization", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -141,16 +146,15 @@ py_library( "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", + "//tensorflow/python:gradients_impl", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:summary", "//tensorflow/python:tensor_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/losses", - "//third_party/py/numpy", ], ) @@ -373,7 +377,10 @@ py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_pip", + "no_windows", + ], deps = [ ":classifier_metrics", "//tensorflow/core:protos_all_py", @@ -518,15 +525,19 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", "//tensorflow/python:training", "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", - 
"//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", "@six_archive//:six", @@ -562,28 +573,114 @@ py_test( deps = [ ":namedtuples", ":stargan_estimator", - ":tuple_losses", "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/learn", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@six_archive//:six", + ], +) + +py_library( + name = "tpu_gan_estimator", + srcs = [ + "python/estimator/python/tpu_gan_estimator.py", + "python/estimator/python/tpu_gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":gan_estimator", + ":namedtuples", + ":train", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/contrib/training:training_py", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:util", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/ops/losses", + ], +) + +py_test( + name = "tpu_gan_estimator_test", + srcs = ["python/estimator/python/tpu_gan_estimator_test.py"], + shard_count = 11, + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":namedtuples", + ":tpu_gan_estimator", + ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + 
"//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) +py_library( + name = "latent_gan_estimator", + srcs = [ + "python/estimator/python/latent_gan_estimator.py", + "python/estimator/python/latent_gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":train", + "//tensorflow/python:clip_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:random_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "latent_gan_estimator_test", + srcs = [ + "python/estimator/python/latent_gan_estimator_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":latent_gan_estimator", + "//tensorflow/python:array_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/ops/losses", + ], +) + py_library( name = "sliced_wasserstein", srcs = [ @@ -618,3 +715,45 @@ py_test( "//third_party/py/numpy", ], ) + +py_library( + name = "spectral_normalization", + srcs = [ + "python/features/python/spectral_normalization.py", + "python/features/python/spectral_normalization_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + 
"//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:standard_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/keras:engine", + ], +) + +py_test( + name = "spectral_normalization_test", + srcs = ["python/features/python/spectral_normalization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":spectral_normalization", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/slim", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/keras:layers", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 9ab86329eaf0e6fd426aef1f552f4e27c2ad65de..4eac4e80cdacd779fdbedef19e4a654196f0caf1 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -1,14 +1,15 @@ -# TensorFlow-GAN (TFGAN) + +# TensorFlow-GAN (TF-GAN) -TFGAN is a lightweight library for training and evaluating Generative +TF-GAN is a lightweight library for training and evaluating Generative Adversarial Networks (GANs). This technique allows you to train a network (called the 'generator') to sample from a distribution, without having to explicitly model the distribution and without writing an explicit loss. For example, the generator could learn to draw samples from the distribution of natural images. For more details on this technique, see ['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. 
See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an +Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an introduction. #### Usage @@ -17,27 +18,27 @@ import tensorflow as tf tfgan = tf.contrib.gan ``` -## Why TFGAN? +## Why TF-GAN? * Easily train generator and discriminator networks with well-tested, flexible [library calls](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py). You can -mix TFGAN, native TF, and other custom frameworks +mix TF-GAN, native TF, and other custom frameworks * Use already implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (ex Wasserstein loss, gradient penalty, mutual information penalty, etc) * [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) them * Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training * Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/) -* Use the TFGAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model -* Improvements in TFGAN infrastructure will automatically benefit your TFGAN project +* Use the TF-GAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to
easily train a GAN model +* Improvements in TF-GAN infrastructure will automatically benefit your TF-GAN project * Stay up-to-date with research as we add more algorithms -## What are the TFGAN components? +## What are the TF-GAN components? -TFGAN is composed of several parts which were design to exist independently. +TF-GAN is composed of several parts which were designed to exist independently. These include the following main pieces (explained in detail below). * [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py): provides the main infrastructure needed to train a GAN. Training occurs in four phases, and each phase can be completed by custom-code or by using a - TFGAN library call. + TF-GAN library call. * [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/): Many common GAN operations and normalization techniques are implemented for @@ -56,14 +57,14 @@ These include the following main pieces (explained in detail below). generative models. * [examples](https://github.com/tensorflow/models/tree/master/research/gan/) - and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TFGAN to make - GAN training easier, or use the more complicated examples to jumpstart your + and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN to make + GAN training easier, or use the more complicated examples to jump-start your own project. These include unconditional and conditional GANs, InfoGANs, adversarial losses on existing networks, and image-to-image translation. ## Training a GAN model -Training in TFGAN typically consists of the following steps: +Training in TF-GAN typically consists of the following steps: 1. Specify the input to your networks. 1. Set up your generator and discriminator using a `GANModel`.
@@ -71,12 +72,12 @@ Training in TFGAN typically consists of the following steps: 1. Create your train ops using a `GANTrainOps`. 1. Run your train ops. -At each stage, you can either use TFGAN's convenience functions, or you can +At each stage, you can either use TF-GAN's convenience functions, or you can perform the step manually for fine-grained control. We provide examples below. There are various types of GAN setups. For instance, you can train a generator to sample unconditionally from a learned distribution, or you can condition on -extra information such as a class label. TFGAN is compatible with many setups, +extra information such as a class label. TF-GAN is compatible with many setups, and we demonstrate a few below: ### Examples @@ -254,9 +255,9 @@ with variable_scope.variable_scope(dis_scope, reuse=True): discriminator_real_outputs = discriminator_fn(images) generator_variables = variables_lib.get_trainable_variables(gen_scope) discriminator_variables = variables_lib.get_trainable_variables(dis_scope) -# Depending on what TFGAN features you use, you don't always need to supply +# Depending on what TF-GAN features you use, you don't always need to supply # every `GANModel` field. At a minimum, you need to include the discriminator -# outputs and variables if you want to use TFGAN to construct losses. +# outputs and variables if you want to use TF-GAN to construct losses. gan_model = tfgan.GANModel( generator_inputs, generated_data, diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index f1946c7f925660eae3aaa650c437e03da1f33d6c..1e6000898f7b8a53ad3f6fa12deebd54bf3a57ff 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""TFGAN is a lightweight library for training and evaluating GANs. +"""TF-GAN is a lightweight library for training and evaluating GANs. In addition to providing the infrastructure for easily training and evaluating GANS, this library contains modules for a TFGAN-backed Estimator, @@ -24,7 +24,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# Collapse TFGAN into a tiered namespace. +# Collapse TF-GAN into a tiered namespace. from tensorflow.contrib.gan.python import estimator from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin from tensorflow.contrib.gan.python import features diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 99d38011ba677f03e198a431634fbb2ce349f912..430266555b723e6ca39dccffc1442dbef5d4a385 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN estimator module. +"""TF-GAN estimator module. GANEstimator provides all the infrastructure support of a TensorFlow Estimator -with the feature support of TFGAN. +with the feature support of TF-GAN. 
""" from __future__ import absolute_import @@ -26,18 +26,25 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.estimator.python import gan_estimator from tensorflow.contrib.gan.python.estimator.python import head +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator from tensorflow.contrib.gan.python.estimator.python import stargan_estimator +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.head import * +from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import * +from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ +_allowed_symbols = ([ 'gan_estimator', 'stargan_estimator', + 'tpu_gan_estimator', + 'latent_gan_estimator', 'head', -] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ +] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ + + tpu_gan_estimator.__all__ + latent_gan_estimator.__all__) remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 3593b501bb738b8f58dce4e40cffbdf410f136b3..dd904611d1a3bb78de8316d5ed29ab0f800f29a9 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division @@ -56,10 +56,10 @@ _summary_type_map = { class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). - This Estimator is backed by TFGAN. The network functions follow the TFGAN API - except for one exception: if either `generator_fn` or `discriminator_fn` have - an argument called `mode`, then the tf.Estimator mode is passed in for that - argument. This helps with operations like batch normalization, which have + This Estimator is backed by TF-GAN. The network functions follow the TF-GAN + API except for one exception: if either `generator_fn` or `discriminator_fn` + have an argument called `mode`, then the tf.Estimator mode is passed in for + that argument. This helps with operations like batch normalization, which have different train and evaluation behavior. Example: @@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator): import tensorflow as tf tfgan = tf.contrib.gan - # See TFGAN's `train.py` for a description of the generator and + # See TF-GAN's `train.py` for a description of the generator and # discriminator API. def generator_fn(generator_inputs): ... @@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. Additionally, if + generator. See `TF-GAN` for more details and examples. Additionally, if it has an argument called `mode`, the Estimator's `mode` will be passed in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch normalization. 
discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. - Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details and examples. generator_loss_fn: The loss function on the generator. Takes a `GANModel` tuple. @@ -233,13 +233,14 @@ def _get_estimator_spec( estimator_spec = _get_eval_estimator_spec( gan_model, gan_loss, get_eval_metric_ops_fn) else: # model_fn_lib.ModeKeys.TRAIN: - gopt = (generator_optimizer() if callable(generator_optimizer) else - generator_optimizer) - dopt = (discriminator_optimizer() if callable(discriminator_optimizer) - else discriminator_optimizer) + if callable(generator_optimizer): + generator_optimizer = generator_optimizer() + if callable(discriminator_optimizer): + discriminator_optimizer = discriminator_optimizer() get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() estimator_spec = _get_train_estimator_spec( - gan_model, gan_loss, gopt, dopt, get_hooks_fn, is_chief=is_chief) + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + get_hooks_fn, is_chief=is_chief) return estimator_spec diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index bc9021050bc010ce75c3091fef868549686c0e90..66af79d1e81bbc450141673dd54d865e5c7932d5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for TFGAN's estimator.py.""" +"""Tests for TF-GAN's estimator.py.""" from __future__ import absolute_import from __future__ import division @@ -23,7 +23,6 @@ import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples @@ -75,8 +74,8 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase): def test_get_gan_model(self, mode): with ops.Graph().as_default(): generator_inputs = {'x': array_ops.ones([3, 4])} - real_data = (array_ops.zeros([3, 4]) if - mode != model_fn_lib.ModeKeys.PREDICT else None) + is_predict = mode == model_fn_lib.ModeKeys.PREDICT + real_data = array_ops.zeros([3, 4]) if not is_predict else None gan_model = estimator._get_gan_model( mode, generator_fn, discriminator_fn, real_data, generator_inputs, add_summaries=False) @@ -139,6 +138,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls): + super(GetEstimatorSpecTest, cls).setUpClass() cls._generator_optimizer = training.GradientDescentOptimizer(1.0) cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) @@ -200,7 +200,6 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) -# TODO(joelshor): Add pandas test. class GANEstimatorIntegrationTest(test.TestCase): def setUp(self): @@ -231,19 +230,19 @@ class GANEstimatorIntegrationTest(test.TestCase): get_eval_metric_ops_fn=get_metrics, model_dir=self._model_dir) - # TRAIN + # Train. num_steps = 10 est.train(train_input_fn, steps=num_steps) - # EVALUTE + # Evaluate. 
scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], scores['loss']) - self.assertIn('mse_custom_metric', six.iterkeys(scores)) + self.assertIn('mse_custom_metric', scores) - # PREDICT + # Predict. predictions = np.array([x for x in est.predict(predict_input_fn)]) self.assertAllEqual(prediction_size, predictions.shape) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 1a0ee6dfc498eb6dc8c97411589d9e35bc352062..cbe990b476c3b17ce61e0826b17d10976fea43c7 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8205bc889dc01c8680e2139393d65723280cfbd0..5b50234a0e33cd297b176f142b358338966b6758 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for TFGAN's head.py.""" +"""Tests for TF-GAN's head.py.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..4e164e24168bb0cc5e9a7cc772081781ea088bb1 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""`tf.Learn` components for `Train Input Estimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = latent_gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..f5afc7731937ed1a82c8ebb5969b2687ffdd583b --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py @@ -0,0 +1,205 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements an estimator wrapper that allows training the input latent space. + +This file implements a latent gan estimator that wraps around a previously +trained GAN. 
The latent gan estimator trains a single variable z, representing +the hidden latent distribution that is the 'noise' input to the GAN. By training +z, the inpainting estimator can move around the latent z space towards +minimizing a specific loss function. + +The latent gan estimator has a few key differences from a normal estimator. + +First: the variables in the estimator should not be saved, as we are not +updating the original GAN and are only adding a new z variable that is meant +to be different for each run. In order to do distributed training using +train_and_evaluate, the Tensorflow RunConfig is expected to save checkpoints +by having either save_checkpoints_steps or save_checkpoints_secs saved. +To avoid this conflict, we purposely set the save_checkpoints_steps value in +the RunConfig to be one step more than the total number of steps that the +inpainter estimator will run. + +Second: we need to specify warm start settings, as we are reloading the +GAN model into a different graph (specifically, one with a new z variable). +The warm start settings defined below reload all GAN variables and ignore the +new z variable (and the optimizer). + +Usage: + + def _generator(net, mode): + ... + + def _discriminator(net, condition, mode): + ... + + def _loss(gan_model, features, labels, add_summaries): + ... + + def optimizer(): + ... + + params = {} + config = tf.estimator.RunConfig() + tmp_dir = path/to/output/storage + + estimator = latent_gan_estimator.get_latent_gan_estimator( + _generator, _discriminator, _loss, optimizer, params, config, tmp_dir) + + def input_fn(): + ... + + estimator.train(input_fn=input_fn) + +See latent_gan_estimator_test.py or tensorflow_models/gan/face_inpainting for +further examples. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import training_util + + +INPUT_NAME = 'new_var_z_input' # The name for the new z space input variable. +OPTIMIZER_NAME = 'latent_gan_optimizer' # The name for the new optimizer vars. + +__all__ = [ + 'get_latent_gan_estimator', +] + + +def _get_latent_gan_model_fn(generator_fn, discriminator_fn, loss_fn, + optimizer): + """Sets up a model function that wraps around a given GAN.""" + def model_fn(features, labels, mode, params): + """Model function defining an inpainting estimator.""" + batch_size = params['batch_size'] + z_shape = [batch_size] + params['z_shape'] + add_summaries = params['add_summaries'] + input_clip = params['input_clip'] + + z = variable_scope.get_variable( + name=INPUT_NAME, initializer=random_ops.truncated_normal(z_shape), + constraint=lambda x: clip_ops.clip_by_value(x, -input_clip, input_clip)) + + generator = functools.partial(generator_fn, mode=mode) + discriminator = functools.partial(discriminator_fn, mode=mode) + gan_model = tfgan_train.gan_model(generator_fn=generator, + discriminator_fn=discriminator, + real_data=labels, + generator_inputs=z, + check_shapes=False) + + loss = loss_fn(gan_model, features, labels, add_summaries) + + # Use a variable scope to make sure that estimator variables don't cause + # save/load problems when restoring from ckpts. 
+ with variable_scope.variable_scope(OPTIMIZER_NAME): + opt = optimizer(learning_rate=params['learning_rate'], + **params['opt_kwargs']) + train_op = opt.minimize( + loss=loss, global_step=training_util.get_or_create_global_step(), + var_list=[z]) + + if add_summaries: + z_grads = gradients_impl.gradients(loss, z) + summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads)) + summary.scalar('z_loss/loss', loss) + + return model_fn_lib.EstimatorSpec(mode=mode, + predictions=gan_model.generated_data, + loss=loss, + train_op=train_op) + return model_fn + + +def get_latent_gan_estimator(generator_fn, discriminator_fn, loss_fn, + optimizer, params, config, ckpt_dir, + warmstart_options=True): + """Gets an estimator that passes gradients to the input. + + This function takes in a generator and adds a trainable z variable that is + used as input to this generator_fn. The generator itself is treated as a black + box through which gradients can pass through without updating any weights. The + result is a trainable way to traverse the GAN latent space. The loss_fn is + used to actually train the z variable. The generator_fn and discriminator_fn + should be previously trained by the tfgan library (on reload, the variables + are expected to follow the tfgan format. It may be possible to use the + latent gan estimator with entirely custom GANs that do not use the tfgan + library as long as the appropriate variables are wired properly). + + Args: + generator_fn: a function defining a Tensorflow graph for a GAN generator. + The weights defined in this graph should already be defined in the given + checkpoint location. Should have 'mode' as an argument. + discriminator_fn: a function defining a Tensorflow graph for a GAN + discriminator. Should have 'mode' as an argument. + loss_fn: a function defining a Tensorflow graph for a GAN loss. Takes in a + GANModel tuple, features, labels, and add_summaries as inputs. 
+ optimizer: a tf.Optimizer class or a function returning a tf.Optimizer. It + is called with `learning_rate` and the `opt_kwargs` keyword arguments. + params: An object containing the following parameters: + - batch_size: an int indicating the size of the training batch. + - z_shape: the desired shape of the input z values (not counting batch). + - learning_rate: a scalar or function defining a learning rate applied to + optimizer. + - input_clip: the amount to clip the z training variable by. + - add_summaries: whether or not to add summaries. + - opt_kwargs: optimizer kwargs. + config: tf.RunConfig. Should point model to output dir and should indicate + whether to save checkpoints (to avoid saving checkpoints, set + save_checkpoints_steps to a number larger than the number of train steps). + The model_dir field in the RunConfig should point to a directory WITHOUT + any saved checkpoints. + ckpt_dir: the directory where the model checkpoints live. The checkpoint is + used to warm start the underlying GAN. This should NOT be the same as + config.model_dir. + warmstart_options: boolean, None, or a WarmStartSettings object. If set to + True, uses a default WarmStartSettings object. If set to False or None, + does not use warm start. If using a custom WarmStartSettings object, make + sure that new variables are properly accounted for when reloading the + underlying GAN. Defaults to True. + Returns: + An `Estimator` that trains the latent input variable z of the GAN. + """ + model_fn = _get_latent_gan_model_fn(generator_fn, discriminator_fn, + loss_fn, optimizer) + + if isinstance(warmstart_options, estimator.WarmStartSettings): + ws = warmstart_options + elif warmstart_options: + # Default WarmStart loads all variable names except INPUT_NAME and + # OPTIMIZER_NAME. 
+ var_regex = '^(?!.*(%s|%s).*)' % (INPUT_NAME, OPTIMIZER_NAME) + ws = estimator.WarmStartSettings(ckpt_to_initialize_from=ckpt_dir, + vars_to_warm_start=var_regex) + else: + ws = None + + if 'opt_kwargs' not in params: + params['opt_kwargs'] = {} + + return estimator.Estimator(model_fn=model_fn, config=config, params=params, + warm_start_from=ws) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac139e532e35f7aae6da0655103a7249fe3382d4 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for latent_gan_estimator. + +See g3.tp.tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import numpy as np +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator +from tensorflow.python.estimator import run_config as run_config +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +class TrainInputEstimatorTest(test.TestCase): + + def test_get_input_training_estimator(self): + """Integration test to make sure the input_training_estimator works.""" + + # Create dummy test input tensors. + true_features = np.reshape(np.random.uniform(size=100), (10, 10)) + true_labels = np.reshape(np.random.uniform(size=100), (5, 20)) + expected_z_output = [[1, -1], [-1, 1]] + + # Fill out required parameters randomly, includes optimizer kwargs. + params = { + 'batch_size': 2, + 'z_shape': [2], + 'learning_rate': 1.0, + 'input_clip': 1.0, + 'add_summaries': False, + 'opt_kwargs': { + 'beta1': 0.1 + } + } + + input_z_shape = [params['batch_size']] + params['z_shape'] + + # Create dummy model functions that represent an underlying GANEstimator and + # the input training wrapper. Make sure that everything is wired up + # correctly in the internals of each dummy function. 
+ def _generator(net, mode): + """The generator function will get the newly created z variable.""" + del mode + self.assertSequenceEqual(net.shape, input_z_shape) + gen_dummy_var = variable_scope.get_variable( + name='generator_dummy_variable', + initializer=array_ops.ones(input_z_shape)) + return net * gen_dummy_var + + def _discriminator(net, condition, mode): + """The discriminator function will get either the z variable or labels.""" + del condition, mode + try: + self.assertSequenceEqual(net.shape, true_labels.shape) + except AssertionError: + self.assertSequenceEqual(net.shape, input_z_shape) + return net + + def _loss(gan_model, features, labels, _): + """Make sure that features and labels are passed in from input.""" + self.assertTrue(np.array_equal(features, true_features)) + self.assertTrue(np.array_equal(labels, true_labels)) + return losses.absolute_difference(expected_z_output, + gan_model.generated_data) + + optimizer = training.AdamOptimizer + + # We are not loading checkpoints, so set the corresponding directory to a + # dummy directory. + tmp_dir = tempfile.mkdtemp() + config = run_config.RunConfig(model_dir=tmp_dir, + save_summary_steps=None, + save_checkpoints_steps=1, + save_checkpoints_secs=None) + + # Get the estimator. Disable warm start so that there is no attempted + # checkpoint reloading. + estimator = latent_gan_estimator.get_latent_gan_estimator( + _generator, _discriminator, _loss, optimizer, params, config, tmp_dir, + warmstart_options=None) + + # Train for a few steps. + def dummy_input(): + return true_features, true_labels + estimator.train(input_fn=dummy_input, steps=10) + + # Make sure the generator variables did not change, but the z variables did + # change. 
+ self.assertTrue(np.array_equal( + estimator.get_variable_value('Generator/generator_dummy_variable'), + np.ones(input_z_shape))) + self.assertTrue(np.array_equal( + estimator.get_variable_value('new_var_z_input'), + expected_z_output)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py index f60e16bc04662b33bc0bb22b5acc8c7fcc7a03ba..2a485e7d47ff10cf34c1b44f8dcc6b1f33c9a05f 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed StarGAN Estimator.""" +"""A TF-GAN-backed StarGAN Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py index 2ec7938c7c4051842c7e982b54c1213b6e841b79..0fcd1b7924eb02f5d617b45af16852baf2e2bb48 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for TFGAN's stargan_estimator.py.""" +"""Tests for TF-GAN's stargan_estimator.py.""" from __future__ import absolute_import from __future__ import division @@ -23,7 +23,6 @@ import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples @@ -80,7 +79,7 @@ class StarGetGANModelTest(test.TestCase, parameterized.TestCase): self.assertEqual(input_data, gan_model.input_data) self.assertIsNotNone(gan_model.generated_data) self.assertIsNotNone(gan_model.generated_data_domain_target) - self.assertEqual(1, len(gan_model.generator_variables)) + self.assertLen(gan_model.generator_variables, 1) self.assertIsNotNone(gan_model.generator_scope) self.assertIsNotNone(gan_model.generator_fn) if mode == model_fn_lib.ModeKeys.PREDICT: @@ -109,7 +108,7 @@ class StarGetGANModelTest(test.TestCase, parameterized.TestCase): gan_model.discriminator_input_data_domain_predication) self.assertIsNotNone( gan_model.discriminator_generated_data_domain_predication) - self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer self.assertIsNotNone(gan_model.discriminator_scope) self.assertIsNotNone(gan_model.discriminator_fn) @@ -163,6 +162,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls): + super(GetEstimatorSpecTest, cls).setUpClass() cls._generator_optimizer = training.GradientDescentOptimizer(1.0) cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) @@ -234,10 +234,10 @@ class StarGANEstimatorIntegrationTest(test.TestCase): # EVALUTE scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) 
self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], scores['loss']) - self.assertIn('mse_custom_metric', six.iterkeys(scores)) + self.assertIn('mse_custom_metric', scores) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..deb381f7be3f9545ed918813ee55aede946f22d4 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""`tf.Learn` components for `TPUGANEstimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = tpu_gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2a22c78a304c7cc66ef069a235483e9279b3b2 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py @@ -0,0 +1,423 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""A TF-GAN-backed GAN Estimator that works on TPU.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as gan_estimator_lib +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.contrib.tpu.python.tpu import tpu_optimizer +from tensorflow.contrib.training.python.training import training +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops.losses import losses + +__all__ = [ + 'TPUGANEstimator', +] + + +class TPUGANEstimator(tpu_estimator.TPUEstimator): + """An estimator for Generative Adversarial Networks (GANs) on TPU. + + This Estimator is backed by TFGAN. It is similar to `tfgan.GANEstimator`, + but works on TPU. + + Example: + + ```python + import tensorflow as tf + tfgan = tf.contrib.gan + + # See TFGAN's `train.py` for a description of the generator and + # discriminator API. + def generator_fn(generator_inputs): + ... + return generated_data + + def discriminator_fn(data, conditioning): + ... + return logits + + # Create GAN estimator. 
+ config = tpu_config.RunConfig(model_dir='/my/dir') + gan_estimator = tfgan.estimator.TPUGANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, + generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + train_batch_size=4, + config=config) + + # Train estimator. + gan_estimator.train(train_input_fn, train_steps) + + # Evaluate resulting estimator. + gan_estimator.evaluate(eval_input_fn, eval_steps) + + # Generate samples from generator. + predictions = np.array([ + x['generated_data'] for x in gan_estimator.predict(predict_input_fn)]) + ``` + """ + + def __init__(self, + # Arguments to construct the `model_fn`. + generator_fn=None, + discriminator_fn=None, + generator_loss_fn=None, + discriminator_loss_fn=None, + generator_optimizer=None, + discriminator_optimizer=None, + get_eval_metric_ops_fn=None, + add_summaries=None, + joint_train=False, + gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1), + # TPUEstimator options. + model_dir=None, + config=None, + params=None, + use_tpu=True, + train_batch_size=None, + eval_batch_size=None, + predict_batch_size=None, + batch_axis=None, + eval_on_tpu=True, + export_to_tpu=True, + warm_start_from=None): + """Initializes a TPUGANEstimator instance. + + Args: + generator_fn: A python function that takes a Tensor, Tensor list, or + Tensor dictionary as inputs and returns the outputs of the GAN + generator. See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. + discriminator_fn: A python function that takes the output of + `generator_fn` or real data in the GAN setup, and `generator_inputs`. + Outputs a Tensor in the range [-inf, inf]. 
See `TFGAN` for more details + and examples. + generator_loss_fn: The loss function on the generator. Takes a `GANModel` + tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `GANModel` tuple. + generator_optimizer: The optimizer for generator updates, or a function + that takes no arguments and returns an optimizer. This function will + be called when the default graph is the `GANEstimator`'s graph, so + utilities like `tf.contrib.framework.get_or_create_global_step` will + work. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + get_eval_metric_ops_fn: A function that takes a list of arguments and + returns a dict of metric results keyed by name. The output of this + function is passed into `tf.estimator.EstimatorSpec` during evaluation. + The arguments must be: + * generator_inputs + * generated_data + * real_data + * discriminator_real_outputs + * discriminator_gen_outputs + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + This is ignored for jobs that run on TPU, such as the train job if + `use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`. + joint_train: A Python boolean. If `True`, jointly train the generator and + the discriminator. If `False`, sequentially train them. See `train.py` + in TFGAN for more details on the differences between the two GAN + training methods. + gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio + of generator to discriminator steps. For now, only supports 1:1 + training. + model_dir: Same as `TPUEstimator`: Directory to save model parameters, + graph, etc. This can also be used to load checkpoints from the + directory into an estimator to continue training a previously saved + model. If `None`, the model_dir in `config` will be used if set. If both + are set, they must be same. If both are `None`, a temporary directory + will be used.
+ config: Same as `TPUEstimator`: A `tpu_config.RunConfig` configuration + object. Cannot be `None`. + params: Same as `TPUEstimator`: An optional `dict` of hyper parameters + that will be passed into `input_fn` and `model_fn`. Keys are names of + parameters, values are basic python types. There are reserved keys for + `TPUEstimator`, including 'batch_size'. + use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is + enabled. Currently, TPU training and evaluation respect this bit, but + eval_on_tpu can override execution of eval. See below. Predict still + happens on CPU. + train_batch_size: Same as `TPUEstimator`: An int representing the global + training batch size. TPUEstimator transforms this global batch size to a + per-shard batch size, as params['batch_size'], when calling `input_fn` + and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be + divisible by total number of replicas. + eval_batch_size: Same as `TPUEstimator`: An int representing evaluation + batch size. Must be divisible by total number of replicas. + predict_batch_size: Same as `TPUEstimator`: An int representing the + prediction batch size. Must be divisible by total number of replicas. + batch_axis: Same as `TPUEstimator`: A python tuple of int values + describing how each tensor produced by the Estimator `input_fn` should + be split across the TPU compute shards. For example, if your input_fn + produced (images, labels) where the images tensor is in `HWCN` format, + your shard dimensions would be [3, 0], where 3 corresponds to the `N` + dimension of your images Tensor, and 0 corresponds to the dimension + along which to split the labels to match up with the corresponding + images. If None is supplied, and per_host_input_for_training is True, + batches will be sharded based on the major dimension. If + tpu_config.per_host_input_for_training is False or `PER_HOST_V2`, + batch_axis is ignored.
+ eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or + GPU. In this case, the model_fn must return `EstimatorSpec` when called + with `mode` as `EVAL`. + export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()` + exports a metagraph for serving on TPU besides the one on CPU. + warm_start_from: Same as `TPUEstimator`: Optional string filepath to a + checkpoint or SavedModel to warm-start from, or a + `tf.estimator.WarmStartSettings` object to fully configure + warm-starting. If the string filepath is provided instead of a + `WarmStartSettings`, then all variables are warm-started, and it is + assumed that vocabularies and Tensor names are unchanged. + + Raises: + ValueError: If loss functions aren't callable. + ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps` + tuple. + ValueError: If `gan_train_steps` isn't 1:1 training. + """ + if not callable(generator_loss_fn): + raise ValueError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise ValueError('discriminator_loss_fn must be callable.') + if not isinstance(gan_train_steps, tfgan_tuples.GANTrainSteps): + raise ValueError( + '`gan_train_steps` must be `tfgan_tuples.GANTrainSteps`. 
Instead, ' + 'was type: %s' % type(gan_train_steps)) + if (gan_train_steps.generator_train_steps != 1 or + gan_train_steps.discriminator_train_steps != 1): + raise ValueError('Estimator currently only supports 1:1 training.') + + if use_tpu: + generator_optimizer = _maybe_make_cross_shard_optimizer( + generator_optimizer) + discriminator_optimizer = _maybe_make_cross_shard_optimizer( + discriminator_optimizer) + + def _model_fn(features, labels, mode, params): + """GANEstimator model function.""" + del params # unused + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + raise ValueError('Mode not recognized: %s' % mode) + real_data = labels # rename inputs for clarity + generator_inputs = features # rename inputs for clarity + + # Make GANModel, which encapsulates the GAN model architectures. + # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then + # remove `add_summaries` logic below. + is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu) + gan_model = gan_estimator_lib._get_gan_model( # pylint:disable=protected-access + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries=None if is_on_tpu else add_summaries) + + # Make the TPUEstimatorSpec, which incorporates the GANModel, losses, eval + # metrics, and optimizers (if required). 
+ estimator_spec = _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + joint_train, is_on_tpu, gan_train_steps) + assert isinstance(estimator_spec, tpu_estimator.TPUEstimatorSpec) + return estimator_spec + + super(TPUGANEstimator, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config, + params=params, + use_tpu=use_tpu, + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + predict_batch_size=predict_batch_size, + batch_axis=batch_axis, + eval_on_tpu=eval_on_tpu, + export_to_tpu=export_to_tpu, + warm_start_from=warm_start_from) + + +def _is_on_tpu(mode, use_tpu, eval_on_tpu): + if mode == model_fn_lib.ModeKeys.TRAIN: + return use_tpu + elif mode == model_fn_lib.ModeKeys.EVAL: + return eval_on_tpu + else: + return False + + +def _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + joint_train, is_on_tpu, gan_train_steps): + """Get the TPUEstimatorSpec for the current mode.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + estimator_spec = tpu_estimator.TPUEstimatorSpec( + mode=mode, predictions={'generated_data': gan_model.generated_data}) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=not is_on_tpu), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=not is_on_tpu)) + # Eval losses for metrics must preserve batch dimension. 
+ gan_loss_no_reduction = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=False, reduction=losses.Reduction.NONE), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=False, reduction=losses.Reduction.NONE)) + estimator_spec = _get_eval_estimator_spec( + gan_model, gan_loss, gan_loss_no_reduction, get_eval_metric_ops_fn) + else: # model_fn_lib.ModeKeys.TRAIN: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=not is_on_tpu), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=not is_on_tpu)) + + # Construct optimizers if arguments were callable. For TPUs, they must be + # `CrossShardOptimizer`. + g_callable = callable(generator_optimizer) + gopt = generator_optimizer() if g_callable else generator_optimizer + d_callable = callable(discriminator_optimizer) + dopt = discriminator_optimizer() if d_callable else discriminator_optimizer + + estimator_spec = _get_train_estimator_spec( + gan_model, gan_loss, gopt, dopt, joint_train, gan_train_steps) + + return estimator_spec + + +def _get_eval_estimator_spec(gan_model, gan_loss, gan_loss_no_reduction, + get_eval_metric_ops_fn): + """Return an TPUEstimatorSpec for the eval case.""" + # Make the metric function and tensor names. 
+ if get_eval_metric_ops_fn is not None: + def metric_fn( + generator_inputs, generated_data, real_data, discriminator_real_outputs, + discriminator_gen_outputs, generator_loss, discriminator_loss): + """`metric_fn` used in TPUEstimator to calculate metrics.""" + eval_metric_ops = { + 'generator_loss': metrics_lib.mean(generator_loss), + 'discriminator_loss': metrics_lib.mean(discriminator_loss), + } + custom_eval_metric_ops = get_eval_metric_ops_fn( + generator_inputs, generated_data, real_data, + discriminator_real_outputs, discriminator_gen_outputs) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('`get_eval_metric_ops_fn` must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) + return eval_metric_ops + tensors = { + 'generator_loss': gan_loss_no_reduction.generator_loss, + 'discriminator_loss': gan_loss_no_reduction.discriminator_loss, + 'generator_inputs': gan_model.generator_inputs, + 'generated_data': gan_model.generated_data, + 'real_data': gan_model.real_data, + 'discriminator_real_outputs': gan_model.discriminator_real_outputs, + 'discriminator_gen_outputs': gan_model.discriminator_gen_outputs, + } + else: + def metric_fn(generator_loss, discriminator_loss): + return { + 'generator_loss': metrics_lib.mean(generator_loss), + 'discriminator_loss': metrics_lib.mean(discriminator_loss), + } + tensors = { + 'generator_loss': gan_loss_no_reduction.generator_loss, + 'discriminator_loss': gan_loss_no_reduction.discriminator_loss, + } + + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + return tpu_estimator.TPUEstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + eval_metrics=(metric_fn, tensors)) + + +def _get_train_estimator_spec( + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + joint_train, gan_train_steps): + """Return a TPUEstimatorSpec for the train case.""" + scalar_loss = 
gan_loss.generator_loss + gan_loss.discriminator_loss + + # Get generator and discriminator update ops. We split them so that update + # ops aren't accidentally run multiple times. For now, throw an error if + # there are update ops that aren't associated with either the generator or + # the discriminator. Might modify the `kwargs` dictionary. + gen_update_ops, dis_update_ops = tfgan_train._get_update_ops( # pylint:disable=protected-access + {}, gan_model.generator_scope.name, gan_model.discriminator_scope.name) + + def gen_train_op(): + with ops.name_scope('generator_train'): + return training.create_train_op( + total_loss=gan_loss.generator_loss, + optimizer=generator_optimizer, + variables_to_train=gan_model.generator_variables, + update_ops=gen_update_ops) + def dis_train_op(): + with ops.name_scope('discriminator_train'): + return training.create_train_op( + total_loss=gan_loss.discriminator_loss, + optimizer=discriminator_optimizer, + variables_to_train=gan_model.discriminator_variables, + update_ops=dis_update_ops) + + # Either optimize the generator and discriminator sequentially or jointly. + tpu_train_op = _combine_train_ops(gen_train_op, dis_train_op, joint_train, + gan_train_steps) + + return tpu_estimator.TPUEstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=tpu_train_op) + + +# TODO(joelshor): Add support for multiple D / G steps. 
+def _combine_train_ops(gen_train_op, dis_train_op, joint_train, + gan_train_steps): + """Combine generator and discriminator train ops into a single op.""" + del gan_train_steps + if joint_train: + tpu_train_op = control_flow_ops.group(gen_train_op(), dis_train_op(), + name='joint_train') + else: + with ops.control_dependencies([dis_train_op()]): + tpu_train_op = gen_train_op() + + return tpu_train_op + + +def _maybe_make_cross_shard_optimizer(opt): + if callable(opt): + if not isinstance(opt(), tpu_optimizer.CrossShardOptimizer): + return lambda: tpu_optimizer.CrossShardOptimizer(opt()) + elif not isinstance(opt, tpu_optimizer.CrossShardOptimizer): + return tpu_optimizer.CrossShardOptimizer(opt) + return opt diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..baf2c28df4b63cff525dcf3ff880730768ad000a --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py @@ -0,0 +1,318 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for TF-GAN's TPU Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib import layers +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator_impl as estimator +from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses +from tensorflow.contrib.tpu.python.tpu import tpu_config +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.contrib.tpu.python.tpu import tpu_optimizer +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.estimator import WarmStartSettings +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework.errors_impl import NotFoundError +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import flags +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import training +from tensorflow.python.training import training_util + +FLAGS = flags.FLAGS + +flags.DEFINE_bool('use_tpu', False, 'Whether to run test on TPU or not.') + + +def generator_fn(noise, mode): + del mode + return layers.fully_connected(noise, tensor_shape.dimension_value( + noise.shape[1])) + + +def discriminator_fn(data, unused_conditioning, mode): + del unused_conditioning, mode + return layers.fully_connected(data, 1) + + +def 
get_dummy_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=array_ops.zeros([3, 4]), + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +def get_metrics(generator_inputs, generated_data, real_data, + discriminator_real_outputs, discriminator_gen_outputs): + del generator_inputs, discriminator_real_outputs, discriminator_gen_outputs + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + real_data, generated_data) + } + + +class GetTPUEstimatorSpecTest(test.TestCase, parameterized.TestCase): + """Tests that the EstimatorSpec is constructed appropriately.""" + + @classmethod + def setUpClass(cls): + super(GetTPUEstimatorSpecTest, cls).setUpClass() + cls._generator_optimizer = tpu_optimizer.CrossShardOptimizer( + training.GradientDescentOptimizer(1.0)) + cls._discriminator_optimizer = tpu_optimizer.CrossShardOptimizer( + training.GradientDescentOptimizer(1.0)) + + @parameterized.named_parameters( + ('joint_train', model_fn_lib.ModeKeys.TRAIN, True), + ('train_sequential', model_fn_lib.ModeKeys.TRAIN, False), + ('eval', model_fn_lib.ModeKeys.EVAL, None), + ('predict', model_fn_lib.ModeKeys.PREDICT, None)) + def test_get_estimator_spec(self, mode, joint_train): + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + spec = estimator._get_estimator_spec( + mode, + 
self._gan_model, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=self._generator_optimizer, + discriminator_optimizer=self._discriminator_optimizer, + joint_train=joint_train, + is_on_tpu=FLAGS.use_tpu, + gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1)) + + self.assertIsInstance(spec, tpu_estimator.TPUEstimatorSpec) + self.assertEqual(mode, spec.mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertEqual({'generated_data': self._gan_model.generated_data}, + spec.predictions) + elif mode == model_fn_lib.ModeKeys.TRAIN: + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.train_op) + self.assertIsNotNone(spec.training_hooks) + elif mode == model_fn_lib.ModeKeys.EVAL: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.eval_metrics) + + +class TPUGANEstimatorIntegrationTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super(TPUGANEstimatorIntegrationTest, self).setUp() + self._model_dir = tempfile.mkdtemp() + self._config = tpu_config.RunConfig(model_dir=self._model_dir) + + def tearDown(self): + super(TPUGANEstimatorIntegrationTest, self).tearDown() + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size, + lr_decay=False, joint_train=True): + def make_opt(): + gstep = training_util.get_or_create_global_step() + lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) + return training.GradientDescentOptimizer(lr) + + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + est = estimator.TPUGANEstimator( + 
generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=gopt, + discriminator_optimizer=dopt, + joint_train=joint_train, + get_eval_metric_ops_fn=get_metrics, + train_batch_size=4, + eval_batch_size=10, + predict_batch_size=8, + use_tpu=FLAGS.use_tpu, + config=self._config) + + # Train. + num_steps_train = 10 + est.train(train_input_fn, steps=num_steps_train) + + # Evaluate. + num_steps_eval = 2 + scores = est.evaluate(eval_input_fn, steps=num_steps_eval) + self.assertIn(ops.GraphKeys.GLOBAL_STEP, scores) + self.assertIn('loss', scores) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', scores) + + # Predict. + predictions = np.array([x['generated_data'] for x in + est.predict(predict_input_fn)]) + self.assertAllEqual(prediction_size, predictions.shape) + + @parameterized.named_parameters( + ('joint_train', True, False, False), + ('train_sequential', False, False, False), + ('lr_decay', False, True, False), + ('train_sequential_ds', False, False, True)) + def test_numpy_input_fn(self, joint_train, lr_decay, return_ds): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + def train_input_fn(params): + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors((data, data)) + .repeat() + .batch(params['batch_size'], drop_remainder=True)) + if return_ds: + return ds + else: + x, y = ds.make_one_shot_iterator().get_next() + return x, y + def eval_input_fn(params): + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors((data, data)) + .repeat() + .batch(params['batch_size'], drop_remainder=True)) + if return_ds: + return ds + else: + x, y = ds.make_one_shot_iterator().get_next() + return x, y + predict_size = 10 + def predict_input_fn(params): + del params # 
unused + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors(data) + .repeat(predict_size) + .batch(1, drop_remainder=True)) + return ds + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[predict_size, input_dim], + lr_decay=lr_decay, + joint_train=joint_train) + + +class TPUGANEstimatorWarmStartTest(test.TestCase): + + def setUp(self): + self._model_dir = self.get_temp_dir() + self._config = tpu_config.RunConfig(model_dir=self._model_dir) + self.new_variable_name = 'new_var' + self.new_variable_value = [1.0, 2.0, 3.0] + + def tearDown(self): + writer_cache.FileWriterCache.clear() + + def _test_warm_start(self, warm_start_from=None): + """Tests whether WarmStartSettings work as intended.""" + def generator_with_new_variable(noise_dict, mode): + variable_scope.get_variable(name=self.new_variable_name, + initializer=self.new_variable_value, + trainable=True) + return generator_fn(noise_dict, mode) + + est = estimator.TPUGANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + train_batch_size=4, + use_tpu=FLAGS.use_tpu, + config=self._config) + + def train_input_fn(params): + data = np.zeros([params['batch_size'], 4], dtype=np.float32) + return data, data + + est.train(train_input_fn, steps=1) + + est_warm = estimator.TPUGANEstimator( + generator_fn=generator_with_new_variable, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + 
config=tpu_config.RunConfig( + model_dir=None if warm_start_from else self._model_dir), + train_batch_size=4, + use_tpu=FLAGS.use_tpu, + warm_start_from=warm_start_from) + + est_warm.train(train_input_fn, steps=1) + + return est_warm + + def test_warm_start_error(self): + """Test if exception when reloading different estimators.""" + with self.assertRaises(NotFoundError): + self._test_warm_start() + + def test_warm_start_success(self): + """Test if GANEstimator allows explicit warm start variable assignment.""" + # Regex matches all variable names in ckpt except for new_var. + var_regex = '^(?!.*%s.*)' % self.new_variable_name + warmstart = WarmStartSettings(ckpt_to_initialize_from=self._model_dir, + vars_to_warm_start=var_regex) + est_warm = self._test_warm_start(warm_start_from=warmstart) + full_variable_name = 'Generator/%s' % self.new_variable_name + self.assertIn(full_variable_name, est_warm.get_variable_names()) + equal_vals = np.array_equal(est_warm.get_variable_value(full_variable_name), + self.new_variable_value) + self.assertTrue(equal_vals) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index f86b8513053a45f9830411f7df2c32d1f36a97b2..92e9abf8a35de1999eb800e169f32220fe47f8cd 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN evaluation module. +"""TF-GAN evaluation module. This module supports techniques such as Inception Score, Frechet Inception distance, and Sliced Wasserstein distance. 
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py index 1c872626a957279132772ae27df7a66a2564e9a5..a52e899114b62cb29752f72aa59f142f4a428aa1 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN.""" +"""Model evaluation tools for TF-GAN.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index a71ee53311c1c057a5b41be0331bf56ce1a82f74..ff19ce2f78e9c86400089e454c88450f01c41764 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN. +"""Model evaluation tools for TF-GAN. These methods come from https://arxiv.org/abs/1606.03498, https://arxiv.org/abs/1706.08500, and https://arxiv.org/abs/1801.01401. 
@@ -41,9 +41,9 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops @@ -346,7 +346,7 @@ def classifier_score(images, classifier_fn, num_batches=1): images, num_or_size_splits=num_batches) # Compute the classifier splits using the memory-efficient `map_fn`. - logits = functional_ops.map_fn( + logits = map_fn.map_fn( fn=classifier_fn, elems=array_ops.stack(generated_images_list), parallel_iterations=1, @@ -387,7 +387,7 @@ def classifier_score_from_logits(logits): # Use maximum precision for best results. logits_dtype = logits.dtype if logits_dtype != dtypes.float64: - logits = math_ops.to_double(logits) + logits = math_ops.cast(logits, dtypes.float64) p = nn_ops.softmax(logits) q = math_ops.reduce_mean(p, axis=0) @@ -505,12 +505,12 @@ def frechet_classifier_distance(real_images, # Compute the activations using the memory-efficient `map_fn`. 
def compute_activations(elems): - return functional_ops.map_fn(fn=classifier_fn, - elems=elems, - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') + return map_fn.map_fn(fn=classifier_fn, + elems=elems, + parallel_iterations=1, + back_prop=False, + swap_memory=True, + name='RunClassifier') real_a = compute_activations(real_imgs) gen_a = compute_activations(generated_imgs) @@ -562,8 +562,8 @@ def mean_only_frechet_classifier_distance_from_activations( activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute means of activations. m = math_ops.reduce_mean(real_activations, 0) @@ -623,8 +623,8 @@ def diagonal_only_frechet_classifier_distance_from_activations( activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute mean and covariance matrices of activations. m, var = nn_impl.moments(real_activations, axes=[0]) @@ -698,15 +698,16 @@ def frechet_classifier_distance_from_activations(real_activations, activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute mean and covariance matrices of activations. 
m = math_ops.reduce_mean(real_activations, 0) m_w = math_ops.reduce_mean(generated_activations, 0) - num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0]) - num_examples_generated = math_ops.to_double( - array_ops.shape(generated_activations)[0]) + num_examples_real = math_ops.cast( + array_ops.shape(real_activations)[0], dtypes.float64) + num_examples_generated = math_ops.cast( + array_ops.shape(generated_activations)[0], dtypes.float64) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m @@ -794,9 +795,9 @@ def kernel_classifier_distance(real_images, on a classifier. num_classifier_batches: Number of batches to split images in to in order to efficiently run them through the classifier network. - max_estimator_block_size: integer, default 1024. The distance estimator - splits samples into blocks for computational efficiency. Larger values are - more computationally expensive but decrease the variance of the distance + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance estimate. dtype: if not None, coerce activations to this dtype before computations. @@ -871,9 +872,9 @@ def kernel_classifier_distance_and_std(real_images, on a classifier. num_classifier_batches: Number of batches to split images in to in order to efficiently run them through the classifier network. - max_estimator_block_size: integer, default 1024. The distance estimator - splits samples into blocks for computational efficiency. Larger values are - more computationally expensive but decrease the variance of the distance + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance estimate. 
Having a smaller block size also gives a better estimate of the standard error. dtype: if not None, coerce activations to this dtype before computations. @@ -894,7 +895,7 @@ def kernel_classifier_distance_and_std(real_images, # Compute the activations using the memory-efficient `map_fn`. def compute_activations(elems): - return functional_ops.map_fn( + return map_fn.map_fn( fn=classifier_fn, elems=elems, parallel_iterations=1, @@ -910,7 +911,7 @@ def kernel_classifier_distance_and_std(real_images, gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) return kernel_classifier_distance_and_std_from_activations( - real_a, gen_a, max_block_size=max_block_size) + real_a, gen_a, max_block_size, dtype) kernel_inception_distance_and_std = functools.partial( @@ -967,14 +968,14 @@ def kernel_classifier_distance_from_activations(real_activations, into blocks for computational efficiency. Larger values are more computationally expensive but decrease the variance of the distance estimate. - dtype: if not None, coerce activations to this dtype before computations. + dtype: If not None, coerce activations to this dtype before computations. Returns: The Kernel Inception Distance. A floating-point scalar of the same type as the output of the activations. """ return kernel_classifier_distance_and_std_from_activations( - real_activations, generated_activations, max_block_size=max_block_size)[0] + real_activations, generated_activations, max_block_size, dtype)[0] def kernel_classifier_distance_and_std_from_activations(real_activations, @@ -1029,7 +1030,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, computationally expensive but decrease the variance of the distance estimate. Having a smaller block size also gives a better estimate of the standard error. - dtype: if not None, coerce activations to this dtype before computations. + dtype: If not None, coerce activations to this dtype before computations. Returns: The Kernel Inception Distance. 
A floating-point scalar of the same type @@ -1080,7 +1081,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, dim = math_ops.cast(real_activations.shape[1], dtype) def compute_kid_block(i): - 'Compute the ith block of the KID estimate.' + """Computes the ith block of the KID estimate.""" r_s = inds_r[i] r_e = inds_r[i + 1] r = real_activations[r_s:r_e] @@ -1098,7 +1099,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, (math_ops.reduce_sum(k_rr) - math_ops.trace(k_rr)) / (m * (m - 1)) + (math_ops.reduce_sum(k_gg) - math_ops.trace(k_gg)) / (n * (n - 1))) - ests = functional_ops.map_fn( + ests = map_fn.map_fn( compute_kid_block, math_ops.range(n_blocks), dtype=dtype, back_prop=False) mn = math_ops.reduce_mean(ests) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index dbff1d2a367e10adc607dafb4c571bb3607a3963..bc7c1057b478fe2656898e68c1a14013b5a71d12 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN classifier_metrics.""" +"""Tests for TF-GAN classifier_metrics.""" from __future__ import absolute_import from __future__ import division @@ -234,7 +234,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): else: logits = classifier_metrics.run_inception(img, _get_dummy_graphdef()) - self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertIsInstance(logits, ops.Tensor) logits.shape.assert_is_compatible_with([batch_size, 1001]) # Check that none of the model variables are trainable. 
@@ -258,7 +258,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): img, _get_dummy_graphdef(), output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) - self.assertTrue(isinstance(pool, ops.Tensor)) + self.assertIsInstance(pool, ops.Tensor) pool.shape.assert_is_compatible_with([batch_size, 2048]) # Check that none of the model variables are trainable. @@ -276,8 +276,8 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_metrics.INCEPTION_FINAL_POOL ]) - self.assertTrue(isinstance(logits, ops.Tensor)) - self.assertTrue(isinstance(pool, ops.Tensor)) + self.assertIsInstance(logits, ops.Tensor) + self.assertIsInstance(pool, ops.Tensor) logits.shape.assert_is_compatible_with([batch_size, 1001]) pool.shape.assert_is_compatible_with([batch_size, 2048]) @@ -290,7 +290,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_metrics.inception_score, array_ops.zeros([6, 299, 299, 3]), num_batches=3) - self.assertTrue(isinstance(score, ops.Tensor)) + self.assertIsInstance(score, ops.Tensor) score.shape.assert_has_rank(0) # Check that none of the model variables are trainable. @@ -302,7 +302,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): distance = _run_with_mock( classifier_metrics.frechet_inception_distance, img, img) - self.assertTrue(isinstance(distance, ops.Tensor)) + self.assertIsInstance(distance, ops.Tensor) distance.shape.assert_has_rank(0) # Check that none of the model variables are trainable. @@ -314,7 +314,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): distance = _run_with_mock(classifier_metrics.kernel_inception_distance, img, img) - self.assertTrue(isinstance(distance, ops.Tensor)) + self.assertIsInstance(distance, ops.Tensor) distance.shape.assert_has_rank(0) # Check that none of the model variables are trainable. 
@@ -365,7 +365,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): unused_image = array_ops.zeros([2, 299, 299, 3]) incscore = _run_with_mock(classifier_metrics.inception_score, unused_image) - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: incscore_np = sess.run(incscore, {'concat:0': logits}) self.assertAllClose(_expected_inception_score(logits), incscore_np) @@ -473,7 +473,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_fn=lambda x: x, max_block_size=600) - with self.test_session() as sess: + with self.cached_session() as sess: actual_kid, actual_std = sess.run(kid_op) expected_kid, expected_std = _expected_kid_and_std(test_pool_real_a, @@ -500,7 +500,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): max_block_size=max_block_size) for block_size in [50, 512, 1000]: - with self.test_session() as sess: + with self.cached_session() as sess: actual_kid, actual_std = sess.run(kid_op, {max_block_size: block_size}) expected_kid, expected_std = _expected_kid_and_std( diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py index 523968bed91f1021ae629bf52c405cf5c2d7b917..326fcb3cdbf2eda66207f134cd2926f09a216a99 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Model evaluation tools for TFGAN.""" +"""Model evaluation tools for TF-GAN.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries.py b/tensorflow/contrib/gan/python/eval/python/summaries.py index ecfdb39499b1e824e02415c0db1de3157e4f3216..1b202dfc97304ddc7ced42d65366aaf419439392 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common TFGAN summaries.""" +"""Common TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index f9995bb19d0d09eaf6fd96d039b0bba1d3a7055c..c7bbd65bbff41c25327733ae1f17a090fb69cb52 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Common TFGAN summaries.""" +"""Common TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division @@ -22,7 +22,7 @@ from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import eval_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import util as loss_util @@ -261,7 +261,7 @@ def add_stargan_image_summaries(stargan_model, summary.image( 'stargan_image_generation', - functional_ops.map_fn( + map_fn.map_fn( _build_image, stargan_model.input_data[:num_images], parallel_iterations=num_images, diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 54a6f8d4d9086ad7fc8db31032677628561e48e8..53fc7cb8ede698c2d8590c7fd3016a884cef9be9 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for TFGAN summaries.""" +"""Tests for TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 4816daf760143af9f1502873b123ffad8e5ec8ce..410c3a02052cd3a07a36a0ba332a80b3c2705d89 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -27,11 +27,13 @@ from __future__ import print_function from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils from tensorflow.contrib.gan.python.features.python import random_tensor_pool +from tensorflow.contrib.gan.python.features.python import spectral_normalization from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * +from tensorflow.contrib.gan.python.features.python.spectral_normalization import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -40,5 +42,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ _allowed_symbols += random_tensor_pool.__all__ +_allowed_symbols += spectral_normalization.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py new file mode 100644 index 
0000000000000000000000000000000000000000..54d3d0a218dec3588844333cd47e1f92489d8df9 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-like layers and utilities that implement Spectral Normalization. + +Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, +et al in ICLR 2018. 
https://openreview.net/pdf?id=B1QRgziT- +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.spectral_normalization_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = spectral_normalization_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc653f0a7907f407e66add5537d1e0a5adb6d8b --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py @@ -0,0 +1,315 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-like layers and utilities that implement Spectral Normalization. + +Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, +et al in ICLR 2018. 
https://openreview.net/pdf?id=B1QRgziT- +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import numbers +import re + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging + +__all__ = [ + 'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer', + 'spectral_normalization_custom_getter', 'keras_spectral_normalization' +] + +# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which then +# can't directly be assigned back to the tf.bfloat16 variable. +_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64) +_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u' + + +def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None): + """Estimates the largest singular value in the weight tensor. + + Args: + w_tensor: The weight matrix whose spectral norm should be computed. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + name: An optional scope name. + + Returns: + The largest singular value (the spectral norm) of w. + """ + with variable_scope.variable_scope(name, 'spectral_norm'): + # The paper says to flatten convnet kernel weights from + # (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D + # kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to + # (KH * KW * C_in, C_out), and similarly for other layers that put output + # channels as last dimension. + # n.b. 
this means that w here is equivalent to w.T in the paper. + w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1])) + + # Persisted approximation of first left singular vector of matrix `w`. + u_var = variable_scope.get_variable( + _PERSISTED_U_VARIABLE_SUFFIX, + shape=(w.shape[0], 1), + dtype=w.dtype, + initializer=init_ops.random_normal_initializer(), + trainable=False) + u = u_var + + # Use power iteration method to approximate spectral norm. + for _ in range(power_iteration_rounds): + # `v` approximates the first right singular vector of matrix `w`. + v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u)) + u = nn.l2_normalize(math_ops.matmul(w, v)) + + # Update persisted approximation. + with ops.control_dependencies([u_var.assign(u, name='update_u')]): + u = array_ops.identity(u) + + u = array_ops.stop_gradient(u) + v = array_ops.stop_gradient(v) + + # Largest singular value of `w`. + spectral_norm = math_ops.matmul( + math_ops.matmul(array_ops.transpose(u), w), v) + spectral_norm.shape.assert_is_fully_defined() + spectral_norm.shape.assert_is_compatible_with([1, 1]) + + return spectral_norm[0][0] + + +def spectral_normalize(w, power_iteration_rounds=1, name=None): + """Normalizes a weight matrix by its spectral norm. + + Args: + w: The weight matrix to be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + name: An optional scope name. + + Returns: + A normalized weight matrix tensor. + """ + with variable_scope.variable_scope(name, 'spectral_normalize'): + w_normalized = w / compute_spectral_norm( + w, power_iteration_rounds=power_iteration_rounds) + return array_ops.reshape(w_normalized, w.get_shape()) + + +def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None): + """Returns a functions that can be used to apply spectral norm regularization. 
+ + Small spectral norms enforce a small Lipschitz constant, which is necessary + for Wasserstein GANs. + + Args: + scale: A scalar multiplier. 0.0 disables the regularizer. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + scope: An optional scope name. + + Returns: + A function with the signature `sn(weights)` that applies spectral norm + regularization. + + Raises: + ValueError: If scale is negative or if scale is not a float. + """ + if isinstance(scale, numbers.Integral): + raise ValueError('scale cannot be an integer: %s' % scale) + if isinstance(scale, numbers.Real): + if scale < 0.0: + raise ValueError( + 'Setting a scale less than 0 on a regularizer: %g' % scale) + if scale == 0.0: + logging.info('Scale of 0 disables regularizer.') + return lambda _: None + + def sn(weights, name=None): + """Applies spectral norm regularization to weights.""" + with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name: + scale_t = ops.convert_to_tensor( + scale, dtype=weights.dtype.base_dtype, name='scale') + return math_ops.multiply( + scale_t, + compute_spectral_norm( + weights, power_iteration_rounds=power_iteration_rounds), + name=name) + + return sn + + +def _default_name_filter(name): + """A filter function to identify common names of weight variables. + + Args: + name: The variable name. + + Returns: + Whether `name` is a standard name for a weight/kernel variables used in the + Keras, tf.layers, tf.contrib.layers or tf.contrib.slim libraries. + """ + match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name) + return match is not None + + +def spectral_normalization_custom_getter(name_filter=_default_name_filter, + power_iteration_rounds=1): + """Custom getter that performs Spectral Normalization on a weight tensor. + + Specifically it divides the weight tensor by its largest singular value. 
This + is intended to stabilize GAN training, by making the discriminator satisfy a + local 1-Lipschitz constraint. + + Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan]. + + [sn-gan]: https://openreview.net/forum?id=B1QRgziT- + + To reproduce an SN-GAN, apply this custom_getter to every weight tensor of + your discriminator. The last dimension of the weight tensor must be the number + of output channels. + + Apply this to layers by supplying this as the `custom_getter` of a + `tf.variable_scope`. For example: + + with tf.variable_scope('discriminator', + custom_getter=spectral_norm_getter()): + net = discriminator_fn(net) + + IMPORTANT: Keras does not respect the custom_getter supplied by the + VariableScope, so Keras users should use `keras_spectral_normalization` + instead of (or in addition to) this approach. + + It is important to carefully select to which weights you want to apply + Spectral Normalization. In general you want to normalize the kernels of + convolution and dense layers, but you do not want to normalize biases. You + also want to avoid normalizing batch normalization (and similar) variables, + but in general such layers play poorly with Spectral Normalization, since the + gamma can cancel out the normalization in other layers. By default we supply a + filter that matches the kernel variable names of the dense and convolution + layers of the tf.layers, tf.contrib.layers, tf.keras and tf.contrib.slim + libraries. If you are using anything else you'll need a custom `name_filter`. + + This custom getter internally creates a variable used to compute the spectral + norm by power iteration. It will update every time the variable is accessed, + which means the normalized discriminator weights may change slightly whilst + training the generator. Whilst unusual, this matches how the paper's authors + implement it, and in general additional rounds of power iteration can't hurt. 
+ + Args: + name_filter: Optionally, a method that takes a Variable name as input and + returns whether this Variable should be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform per step. A higher number yeilds a better approximation of the + true spectral norm. + + Returns: + A custom getter function that applies Spectral Normalization to all + Variables whose names match `name_filter`. + + Raises: + ValueError: If name_filter is not callable. + """ + if not callable(name_filter): + raise ValueError('name_filter must be callable') + + def _internal_getter(getter, name, *args, **kwargs): + """A custom getter function that applies Spectral Normalization. + + Args: + getter: The true getter to call. + name: Name of new/existing variable, in the same format as + tf.get_variable. + *args: Other positional arguments, in the same format as tf.get_variable. + **kwargs: Keyword arguments, in the same format as tf.get_variable. + + Returns: + The return value of `getter(name, *args, **kwargs)`, spectrally + normalized. + + Raises: + ValueError: If used incorrectly, or if `dtype` is not supported. + """ + if not name_filter(name): + return getter(name, *args, **kwargs) + + if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX): + raise ValueError( + 'Cannot apply Spectral Normalization to internal variables created ' + 'for Spectral Normalization. Tried to normalized variable [%s]' % + name) + + if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM: + raise ValueError('Disallowed data type {}'.format(kwargs['dtype'])) + + # This layer's weight Variable/PartitionedVariable. 
+ w_tensor = getter(name, *args, **kwargs) + + if len(w_tensor.get_shape()) < 2: + raise ValueError( + 'Spectral norm can only be applied to multi-dimensional tensors') + + return spectral_normalize( + w_tensor, + power_iteration_rounds=power_iteration_rounds, + name=(name + '/spectral_normalize')) + + return _internal_getter + + +@contextlib.contextmanager +def keras_spectral_normalization(name_filter=_default_name_filter, + power_iteration_rounds=1): + """A context manager that enables Spectral Normalization for Keras. + + Keras doesn't respect the `custom_getter` in the VariableScope, so this is a + bit of a hack to make things work. + + Usage: + with keras_spectral_normalization(): + net = discriminator_fn(net) + + Args: + name_filter: Optionally, a method that takes a Variable name as input and + returns whether this Variable should be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform per step. A higher number yeilds a better approximation of the + true spectral norm. + + Yields: + A context manager that wraps the standard Keras variable creation method + with the `spectral_normalization_custom_getter`. 
+ """ + original_make_variable = keras_base_layer_utils.make_variable + sn_getter = spectral_normalization_custom_getter( + name_filter=name_filter, power_iteration_rounds=power_iteration_rounds) + + def make_variable_wrapper(name, *args, **kwargs): + return sn_getter(original_make_variable, name, *args, **kwargs) + + keras_base_layer_utils.make_variable = make_variable_wrapper + + yield + + keras_base_layer_utils.make_variable = original_make_variable diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea21f70ec01950cfef5e4fa851c78b219d6062f --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py @@ -0,0 +1,354 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for features.spectral_normalization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import slim
+from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl as spectral_normalization
+from tensorflow.contrib.layers.python.layers import layers as contrib_layers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.layers import convolutional as keras_convolutional
+from tensorflow.python.keras.layers import core as keras_core
+from tensorflow.python.layers import convolutional as layers_convolutional
+from tensorflow.python.layers import core as layers_core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class SpectralNormalizationTest(test.TestCase):
+
+  def testComputeSpectralNorm(self):
+    weights = variable_scope.get_variable(
+        'w', dtype=dtypes.float32, shape=[2, 3, 50, 100])
+    weights = math_ops.multiply(weights, 10.0)
+    s = linalg_ops.svd(
+        array_ops.reshape(weights, [-1, weights.shape[-1]]), compute_uv=False)
+    true_sn = s[..., 0]
+    estimated_sn = spectral_normalization.compute_spectral_norm(weights)
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+      np_true_sn = sess.run(true_sn)
+      for i in range(50):
+        est = sess.run(estimated_sn)
+        if i < 1:
+          np_est_1 = est
+        if i < 4:
+          np_est_5 = est
+        if i < 9:
+          np_est_10 = est
+        np_est_50 = est
+
+      # Check that the estimate improves with more iterations.
+      self.assertAlmostEqual(np_true_sn, np_est_50, 0)
+      self.assertGreater(
+          abs(np_true_sn - np_est_10), abs(np_true_sn - np_est_50))
+      self.assertGreater(
+          abs(np_true_sn - np_est_5), abs(np_true_sn - np_est_10))
+      self.assertGreater(abs(np_true_sn - np_est_1), abs(np_true_sn - np_est_5))
+
+  def testSpectralNormalize(self):
+    weights = variable_scope.get_variable(
+        'w', dtype=dtypes.float32, shape=[2, 3, 50, 100])
+    weights = math_ops.multiply(weights, 10.0)
+    normalized_weights = spectral_normalization.spectral_normalize(
+        weights, power_iteration_rounds=1)
+
+    unnormalized_sigma = linalg_ops.svd(
+        array_ops.reshape(weights, [-1, weights.shape[-1]]),
+        compute_uv=False)[..., 0]
+    normalized_sigma = linalg_ops.svd(
+        array_ops.reshape(normalized_weights, [-1, weights.shape[-1]]),
+        compute_uv=False)[..., 0]
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+      s0 = sess.run(unnormalized_sigma)
+
+      for i in range(50):
+        sigma = sess.run(normalized_sigma)
+        if i < 1:
+          s1 = sigma
+        if i < 5:
+          s5 = sigma
+        if i < 10:
+          s10 = sigma
+        s50 = sigma
+
+      self.assertAlmostEqual(1., s50, 0)
+      self.assertGreater(abs(s10 - 1.), abs(s50 - 1.))
+      self.assertGreater(abs(s5 - 1.), abs(s10 - 1.))
+      self.assertGreater(abs(s1 - 1.), abs(s5 - 1.))
+      self.assertGreater(abs(s0 - 1.), abs(s1 - 1.))
+
+  def _testLayerHelper(self, build_layer_fn, w_shape, b_shape, is_keras=False):
+    x = array_ops.placeholder(dtypes.float32, shape=[2, 10, 10, 3])
+
+    w_initial = np.random.randn(*w_shape) * 10
+    w_initializer = init_ops.constant_initializer(w_initial)
+    b_initial = np.random.randn(*b_shape)
+    b_initializer = init_ops.constant_initializer(b_initial)
+
+    if is_keras:
+      context_manager = spectral_normalization.keras_spectral_normalization()
+    else:
+      getter = spectral_normalization.spectral_normalization_custom_getter()
+      context_manager = variable_scope.variable_scope('', custom_getter=getter)
+
+    with context_manager:
+      (net,
+       expected_normalized_vars, expected_not_normalized_vars) = build_layer_fn(
+           x, w_initializer, b_initializer)
+
+    x_data = np.random.rand(*x.shape)
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+
+      # Before running a forward pass we still expect the variables values to
+      # differ from the initial value because of the normalizer.
+      w_befores = []
+      for name, var in expected_normalized_vars.items():
+        w_before = sess.run(var)
+        w_befores.append(w_before)
+        self.assertFalse(
+            np.allclose(w_initial, w_before),
+            msg=('%s appears not to be normalized. Before: %s After: %s' %
+                 (name, w_initial, w_before)))
+
+      # Not true for the unnormalized variables.
+      for name, var in expected_not_normalized_vars.items():
+        b_before = sess.run(var)
+        self.assertTrue(
+            np.allclose(b_initial, b_before),
+            msg=('%s appears to be unexpectedly normalized. '
+                 'Before: %s After: %s' % (name, b_initial, b_before)))
+
+      # Run a bunch of forward passes.
+      for _ in range(1000):
+        _ = sess.run(net, feed_dict={x: x_data})
+
+      # The repeated forward passes should have refined the spectral norm
+      # estimate, so each variable must have moved and now be close to the
+      # exact value: w_initial divided by its largest singular value.
+      _, s, _ = np.linalg.svd(w_initial.reshape([-1, w_shape[-1]]))
+      exactly_normalized = w_initial / s[0]
+      for w_before, (name, var) in zip(w_befores,
+                                       expected_normalized_vars.items()):
+        w_after = sess.run(var)
+        self.assertFalse(
+            np.allclose(w_before, w_after, rtol=1e-8, atol=1e-8),
+            msg=('%s did not improve over many iterations. '
+                 'Before: %s After: %s' % (name, w_before, w_after)))
+        self.assertAllClose(
+            exactly_normalized,
+            w_after,
+            rtol=1e-4,
+            atol=1e-4,
+            msg=('Estimate of spectral norm for %s was inaccurate. '
+                 'Normalized matrices do not match. '
+                 'Estimate: %s Actual: %s' % (name, w_after,
+                                              exactly_normalized)))
+
+  def testConv2D_Layers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      layer = layers_convolutional.Conv2D(
+          filters=3,
+          kernel_size=3,
+          padding='same',
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'tf.layers.Conv2d.kernel': layer.kernel}
+      expected_not_normalized_vars = {'tf.layers.Conv2d.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_ContribLayers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['CONTRIB_LAYERS_CONV2D_WEIGHTS'],
+          'biases': ['CONTRIB_LAYERS_CONV2D_BIASES']
+      }
+      net = contrib_layers.conv2d(
+          x,
+          3,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'contrib.layers.conv2d.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {
+          'contrib.layers.conv2d.bias': bias_vars[0]
+      }
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_Slim(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['SLIM_CONV2D_WEIGHTS'],
+          'biases': ['SLIM_CONV2D_BIASES']
+      }
+      net = slim.conv2d(
+          x,
+          3,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('SLIM_CONV2D_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('SLIM_CONV2D_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {'slim.conv2d.weights': weight_vars[0]}
+      expected_not_normalized_vars = {'slim.conv2d.bias': bias_vars[0]}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_Keras(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      layer = keras_convolutional.Conv2D(
+          filters=3,
+          kernel_size=3,
+          padding='same',
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'keras.layers.Conv2d.kernel': layer.kernel}
+      expected_not_normalized_vars = {'keras.layers.Conv2d.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,), is_keras=True)
+
+  def testFC_Layers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      x = layers_core.Flatten()(x)
+      layer = layers_core.Dense(
+          units=3,
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'tf.layers.Dense.kernel': layer.kernel}
+      expected_not_normalized_vars = {'tf.layers.Dense.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_ContribLayers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['CONTRIB_LAYERS_FC_WEIGHTS'],
+          'biases': ['CONTRIB_LAYERS_FC_BIASES']
+      }
+      x = contrib_layers.flatten(x)
+      net = contrib_layers.fully_connected(
+          x,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('CONTRIB_LAYERS_FC_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('CONTRIB_LAYERS_FC_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'contrib.layers.fully_connected.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {
+          'contrib.layers.fully_connected.bias': bias_vars[0]
+      }
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_Slim(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['SLIM_FC_WEIGHTS'],
+          'biases': ['SLIM_FC_BIASES']
+      }
+      x = slim.flatten(x)
+      net = slim.fully_connected(
+          x,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('SLIM_FC_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('SLIM_FC_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'slim.fully_connected.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {'slim.fully_connected.bias': bias_vars[0]}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_Keras(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      x = keras_core.Flatten()(x)
+      layer = keras_core.Dense(
+          units=3,
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'keras.layers.Dense.kernel': layer.kernel}
+      expected_not_normalized_vars = {'keras.layers.Dense.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,), is_keras=True)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index a0a86c6337eefa756a209635faa70db686a36247..1f1ae2df4d6def618e86aced3296ac89c836eab7 100644
---
a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -28,7 +28,7 @@ wasserstein_gradient_penalty All losses must be able to accept 1D or 2D Tensors, so as to be compatible with patchGAN style losses (https://arxiv.org/abs/1611.07004). -To make these losses usable in the TFGAN framework, please create a tuple +To make these losses usable in the TF-GAN framework, please create a tuple version of the losses with `losses_utils.py`. """ @@ -38,6 +38,7 @@ from __future__ import print_function from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -69,6 +70,10 @@ __all__ = [ ] +def _to_float(tensor): + return math_ops.cast(tensor, dtypes.float32) + + # Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875). 
def wasserstein_generator_loss( discriminator_gen_outputs, @@ -98,7 +103,7 @@ def wasserstein_generator_loss( """ with ops.name_scope(scope, 'generator_wasserstein_loss', ( discriminator_gen_outputs, weights)) as scope: - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) loss = - discriminator_gen_outputs loss = losses.compute_weighted_loss( @@ -144,8 +149,8 @@ def wasserstein_discriminator_loss( with ops.name_scope(scope, 'discriminator_wasserstein_loss', ( discriminator_real_outputs, discriminator_gen_outputs, real_weights, generated_weights)) as scope: - discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs = _to_float(discriminator_real_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) @@ -320,7 +325,7 @@ def wasserstein_gradient_penalty( generated_data: Output of the generator. generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator. - discriminator_fn: A discriminator function that conforms to TFGAN API. + discriminator_fn: A discriminator function that conforms to TF-GAN API. discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when computing the gradient norm. 
@@ -647,7 +652,7 @@ def least_squares_generator_loss( """ with ops.name_scope(scope, 'lsq_generator_loss', (discriminator_gen_outputs, real_label)) as scope: - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) loss = math_ops.squared_difference( discriminator_gen_outputs, real_label) / 2.0 loss = losses.compute_weighted_loss( @@ -702,8 +707,8 @@ def least_squares_discriminator_loss( """ with ops.name_scope(scope, 'lsq_discriminator_loss', (discriminator_gen_outputs, real_label)) as scope: - discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs = _to_float(discriminator_real_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index e3c780ac1a0f0ef15ff993bd3a9bf9730dcb45b8..44ee0f52696dc1cdcd91286a80b2d4b42be93a4d 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -403,7 +403,9 @@ class _PenaltyTest(object): def test_all_correct(self): loss = self._penalty_fn(**self._kwargs) self.assertEqual(self._expected_dtype, loss.dtype) - self.assertEqual(self._expected_op_name, loss.op.name) + # NOTE: Op names will change, it is inappropriate to include them in tests. + # See go/tf-breaking-change. 
+ # self.assertEqual(self._expected_op_name, loss.op.name) with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss, loss.eval(), 6) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index 221c70c38bd432a6be7f6cda9c6700aa2255821f..76e57df7f646547037b3461ac44f7ee5b971406c 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN utilities for loss functions that accept GANModel namedtuples. +"""TF-GAN utilities for loss functions that accept GANModel namedtuples. The losses and penalties in this file all correspond to losses in `losses_impl.py`. Losses in that file take individual arguments, whereas in this diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 969b68449d9c82f9f9144a8657cd8932b38fd0f7..73dfee4fdeec87cf0bac5eb675fd02a64a9ad7f5 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Named tuples for TFGAN. +"""Named tuples for TF-GAN. -TFGAN training occurs in four steps, and each step communicates with the next -step via one of these named tuples. At each step, you can either use a TFGAN +TF-GAN training occurs in four steps, and each step communicates with the next +step via one of these named tuples. 
At each step, you can either use a TF-GAN helper function in `train.py`, or you can manually construct a tuple. """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 4c7bee41b33ce1fee46d374ca5fd1c0b603762f9..9bff8090d93d3ad7def69726073accfb234ef301 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The TFGAN project provides a lightweight GAN training/testing framework. +"""The TF-GAN project provides a lightweight GAN training/testing framework. This file contains the core helper functions to create and train a GAN model. See the README or examples in `tensorflow_models` for details on how to use. -TFGAN training occurs in four steps: +TF-GAN training occurs in four steps: 1) Create a model 2) Add a loss 3) Create train ops @@ -645,9 +645,10 @@ def gan_loss( type(model)) # Optionally create pooled model. - pooled_model = ( - _tensor_pool_adjusted_model(model, tensor_pool_fn) - if tensor_pool_fn else model) + if tensor_pool_fn: + pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn) + else: + pooled_model = model # Create standard losses. 
gen_loss = generator_loss_fn(model, add_summaries=add_summaries) @@ -665,10 +666,11 @@ def gan_loss( if _use_aux_loss(mutual_information_penalty_weight): gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_info_loss = ( - gen_info_loss - if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty( - pooled_model, add_summaries=add_summaries)) + if tensor_pool_fn is None: + dis_info_loss = gen_info_loss + else: + dis_info_loss = tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries) gen_loss += mutual_information_penalty_weight * gen_info_loss dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): @@ -755,7 +757,9 @@ def cyclegan_loss( return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) - +# Begin google-internal +# The four major parts can be found here: http://screen/tMRMBAohDYG. +# End google-internal def stargan_loss( model, generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper( @@ -774,8 +778,6 @@ def stargan_loss( add_summaries=True): """StarGAN Loss. - The four major part can be found here: http://screen/tMRMBAohDYG. - Args: model: (StarGAN) Model output of the stargan_model() function call. generator_loss_fn: The loss function on the generator. Takes a @@ -929,7 +931,7 @@ def gan_train_ops( **kwargs): """Returns GAN train ops. - The highest-level call in TFGAN. It is composed of functions that can also + The highest-level call in TF-GAN. It is composed of functions that can also be called, should a user require more control over some part of the GAN training process. 
diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index e534fdc17749974ebe713c2730682bea6d7a85e4..bf8b66dcfa5e44a03107cdf1ef8b04e1dbff4a9c 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -17,11 +17,6 @@ filegroup( ]), ) -load( - "//tensorflow:tensorflow.bzl", - "tf_cuda_library", -) - # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", @@ -37,7 +32,7 @@ tf_proto_library_cc( ], ) -tf_cuda_library( +cc_library( name = "gdr_memory_manager", srcs = ["gdr_memory_manager.cc"], hdrs = ["gdr_memory_manager.h"], @@ -58,7 +53,7 @@ tf_cuda_library( ], ) -tf_cuda_library( +cc_library( name = "gdr_worker", srcs = ["gdr_worker.cc"], hdrs = ["gdr_worker.h"], @@ -66,7 +61,6 @@ tf_cuda_library( ":gdr_memory_manager", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:graph_mgr", @@ -100,15 +94,37 @@ cc_library( ], ) +cc_library( + name = "gdr_collective_executor_mgr", + srcs = ["gdr_collective_executor_mgr.cc"], + hdrs = ["gdr_collective_executor_mgr.h"], + deps = [ + ":gdr_memory_manager", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:cancellable_call", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", + "//tensorflow/core/distributed_runtime:request_id", + "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", + "//tensorflow/core/distributed_runtime:worker_cache", + ], +) + cc_library( name = "gdr_server_lib", srcs = ["gdr_server_lib.cc"], hdrs = ["gdr_server_lib.h"], linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel deps = [ + ":gdr_collective_executor_mgr", ":gdr_memory_manager", 
":gdr_rendezvous_mgr", ":gdr_worker", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], alwayslink = 1, diff --git a/tensorflow/contrib/gdr/README.md b/tensorflow/contrib/gdr/README.md index 8242d93f129904828a11b61d48f2df8fb0f88bc3..711adc865f37fc84550e4b45d9f0c7fff421a0dc 100644 --- a/tensorflow/contrib/gdr/README.md +++ b/tensorflow/contrib/gdr/README.md @@ -114,7 +114,16 @@ Caveats In current implementation, only tensors that reside in host memory or in GPU memory such that the GPU is adjacent to an RDMA capable NIC will use direct RDMA as its transport. When RDMA is available but not GDR, a temporary tensor copy on host memory will be used as RDMA source/destination (and copied from/to the target device). When there is no RDMA device present, it can even fallback to the original gRPC runtime. While it is theoretically possible to mix GDR enabled TF with non-GDR deployments in the same job, make sure the environment is properly setup so the GDR mode is enabled whenever possible (i.e. do not fall back to gRPC when it is not absolutely necessary). -In the original design (as in the reference), tensor buffers are only registered to NIC when we could determine that the tensor will be either a source of Send or a sink of Recv across physical machine boundary. However, to implement the precise allocations, we need to change all the devices to possibly return a NIC compatible allocator. As GDR is currently in contrib, we would like to avoid the unnecessary code disruption to the TF core, so we allocate all tensors from NIC-registered buffers using a BFC allocator. 
This behaviour is similar to the effect of enabling the extra GPU option `force_gpu_compatible`, which allocate all host tensors in GPU-registered buffers no matter they will be transferred from/to GPUs or not.
+In the original design (as in the reference), tensor buffers are only registered
+to NIC when we could determine that the tensor will be either a source of Send
+or a sink of Recv across physical machine boundary. However, to implement the
+precise allocations, we need to change all the devices to possibly return a NIC
+compatible allocator. As GDR is currently in contrib, we would like to avoid the
+unnecessary code disruption to the TF core, so we allocate all tensors from
+NIC-registered buffers using a BFC allocator. This behavior is similar to the
+effect of enabling the extra GPU option `force_gpu_compatible`, which allocates
+all host tensors in GPU-registered buffers regardless of whether they will be
+transferred from/to GPUs.
 
 Reference
 ===
diff --git a/tensorflow/contrib/gdr/gdr.proto b/tensorflow/contrib/gdr/gdr.proto
index c0b89245b150bfa49cb527d25b6e1f324f353b25..bd438787c3374be6ead4f6233101fd1f548643ea 100644
--- a/tensorflow/contrib/gdr/gdr.proto
+++ b/tensorflow/contrib/gdr/gdr.proto
@@ -9,5 +9,4 @@ message RemoteMemoryRegion {
   uint64 addr = 3;
   uint32 rkey = 4;
   uint32 tensor_key = 5;
-  uint64 checksum = 6;
 }
diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
new file mode 100644
index 0000000000000000000000000000000000000000..755cbdff31cd7ca31579e0d64399d681dc24ad81
--- /dev/null
+++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
@@ -0,0 +1,159 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/distributed_runtime/cancellable_call.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+class WorkerCacheInterface;
+
+namespace {
+
+class RecvBufCall : public CancellableCall {  // One cancellable RecvBufAsync RPC.
+ public:
+  RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
+              const string& key, Device* to_device,
+              DeviceContext* to_device_ctx,
+              const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+              const DeviceLocality& client_locality,
+              const DeviceLocality& server_locality,
+              CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
+      : CancellableCall(cancel_mgr, peer_task, wc) {
+    req_.set_step_id(step_id);
+    req_.set_buf_rendezvous_key(key);
+    *req_.mutable_client_locality() = client_locality;
+    *req_.mutable_server_locality() = server_locality;
+    req_.set_num_bytes(to_tensor->TotalBytes());
+    req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
+    req_.set_src_device(peer_device);
+    req_.set_dst_device(to_device->name());
+    req_.set_request_id(GetUniqueRequestId());
+  }
+
+  ~RecvBufCall() override {}
+
+  void IssueCall(const StatusCallback& done) override {
+    wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
+  }
+
+  RecvBufRequest req_;
+  RecvBufResponse resp_;
+};
+
+class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
+ public:
+  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
+                                    DeviceResolverInterface* dev_resolver,
+                                    WorkerCacheInterface* worker_cache,
+                                    int64 step_id,
+                                    RemoteMemoryManager* remote_memory_manager)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+        worker_cache_(worker_cache),
+        remote_memory_manager_(remote_memory_manager) {}
+
+  ~CollectiveRemoteAccessDistributed() override {}
+
+  void RecvFromPeer(const string& peer_device, const string& peer_task,
+                    bool peer_is_local, const string& key, Device* to_device,
+                    DeviceContext* to_device_ctx,
+                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+                    const DeviceLocality& client_locality,
+                    int dev_to_dev_stream_index,
+                    const StatusCallback& done) override {
+    if (peer_is_local) {
+      CollectiveRemoteAccessLocal::RecvFromPeer(
+          peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
+          to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
+          done);
+      return;
+    }
+
+    // State that needs to be threaded through a couple of async calls
+    // in order to make this function completely non-blocking.
+    struct State {
+      DeviceLocality server_locality;
+      std::unique_ptr<RecvBufCall> call;
+    };
+    State* state = new State;
+
+    // Logic to be executed on the RecvBufAsync callback.
+    auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
+                              to_device_ctx, to_tensor, done](const Status& s) {
+      if (s.ok()) {
+        remote_memory_manager_->TensorFromTransportOptions(
+            to_tensor, state->call->resp_.transport_options(), to_device,
+            to_device_ctx, to_alloc_attr.on_host(), done);
+      } else {
+        if (errors::IsFailedPrecondition(s)) dev_resolver_->ClearTask(peer_task);
+        // `done` was not handed to the memory manager, so report `s` here.
+        done(s);
+      }
+      delete state;
+    };
+
+    // Logic to execute once we have the device locality for the server-side
+    // device.
+    auto dev_locality_callback = [this, state, peer_device, peer_task, key,
+                                  to_device, to_device_ctx, to_alloc_attr,
+                                  to_tensor, client_locality,
+                                  recv_buf_callback](const Status& s) {
+      if (!s.ok()) {
+        recv_buf_callback(s);
+      } else {
+        state->call.reset(new RecvBufCall(
+            step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
+            to_alloc_attr, to_tensor, client_locality, state->server_locality,
+            &cancel_mgr_, worker_cache_));
+        state->call->Start(recv_buf_callback);
+      }
+    };
+
+    dev_resolver_->GetLocalityAsync(
+        peer_device, peer_task, &state->server_locality, dev_locality_callback);
+  }
+
+  void StartAbort(const Status& s) override {
+    CollectiveRemoteAccessLocal::StartAbort(s);
+    cancel_mgr_.StartCancel();
+  }
+
+ protected:
+  WorkerCacheInterface* worker_cache_;  // Not owned
+  CancellationManager cancel_mgr_;
+  RemoteMemoryManager* remote_memory_manager_;
+};
+
+}  // namespace
+
+CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) {
+  CollectiveRemoteAccessDistributed* rma =
+      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+                                            worker_cache_, step_id,
+                                            remote_memory_manager_);
+  return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
+                                    &gpu_ring_order_);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h
new file mode 100644
index
0000000000000000000000000000000000000000..1417e51e82c31035f058e8e9b546e04fb0ad97b8 --- /dev/null +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ +#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ + +#include "tensorflow/contrib/gdr/gdr_memory_manager.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class ConfigProto; +class DeviceMgr; +class WorkerCacheInterface; +class StepSequenceRequest; +class StepSequenceResponse; + +// An implementation of CollectiveExecutorMgr for a distributed environment +// that uses WorkerInterface::RecvBufAsync to route data transfers over RDMA. 
+class GdrCollectiveExecutorMgr : public RpcCollectiveExecutorMgr {
+ public:
+  GdrCollectiveExecutorMgr(
+      const ConfigProto& config, const DeviceMgr* dev_mgr,
+      std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+      std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+      WorkerCacheInterface* worker_cache, const string& task_name,
+      RemoteMemoryManager* remote_memory_manager)
+      : RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
+                                 std::move(param_resolver), worker_cache,
+                                 task_name),
+        remote_memory_manager_(remote_memory_manager) {}
+
+  ~GdrCollectiveExecutorMgr() override {}
+
+ protected:
+  CollectiveExecutor* Create(int64 step_id) override;
+
+ private:
+  RemoteMemoryManager* remote_memory_manager_;  // Not owned.
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index 53587fcf3050f313c85485f77ce411cba7faccff..7321e973191c4cc45f88735c6be7f2f67fe71c39 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -26,17 +26,14 @@ limitations under the License.
#include #include #include -#include #include "tensorflow/contrib/gdr/gdr.pb.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/process_state.h" -#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" @@ -76,15 +73,14 @@ int TryToReadNumaNode(ibv_device* device) { std::ifstream ifs(filename.c_str()); string content; - CHECK(std::getline(ifs, content)); + const auto& ret = std::getline(ifs, content); + if (!ret) { + return port::kNUMANoAffinity; + } int32 value; if (strings::safe_strto32(content, &value)) { if (value < 0) { - LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" - << value - << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; return port::kNUMANoAffinity; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; @@ -114,7 +110,7 @@ class GdrMemoryManager : public RemoteMemoryManager { public: GdrMemoryManager(const string& host, const string& port); - virtual ~GdrMemoryManager(); + virtual ~GdrMemoryManager() {} virtual Status Init() override; @@ -140,7 +136,7 @@ class GdrMemoryManager : public RemoteMemoryManager { return ptr < reinterpret_cast(other->addr) + other->length; } - ibv_mr* FindMemoryRegion(void* addr, size_t length); + ibv_mr* FindMemoryRegion(const Tensor* tensor); void InsertMemoryRegion(void* addr, size_t length, const std::string& allocator_name); @@ -152,7 +148,6 @@ class GdrMemoryManager : public RemoteMemoryManager { const string port_; RdmaEndpointPtr listening_; std::atomic 
stopped_; - int epfd_; int numa_node_; // Server side endpoints @@ -163,15 +158,19 @@ class GdrMemoryManager : public RemoteMemoryManager { std::atomic next_key_; // Server side on-the-fly tensor buffers - mutex server_mu_; - std::map tensor_buffers_ - GUARDED_BY(server_mu_); + mutex buf_mu_; + std::map tensor_buffers_ GUARDED_BY(buf_mu_); // Client side endpoints mutex client_mu_; std::map, RdmaEndpointPtr> clients_ GUARDED_BY(client_mu_); + // Client side callbacks + mutex callback_mu_; + std::map tensor_callbacks_ + GUARDED_BY(callback_mu_); + // Managed memory regions mutex alloc_mu_; std::vector mrs_ GUARDED_BY(alloc_mu_); @@ -184,16 +183,9 @@ GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) port_(port), listening_(nullptr, EndpointDeleter), stopped_(true), - next_key_(0) {} - -GdrMemoryManager::~GdrMemoryManager() { close(epfd_); } + next_key_(static_cast(random::New64())) {} Status GdrMemoryManager::Init() { - epfd_ = epoll_create1(0); - if (epfd_ == -1) { - return errors::Unavailable(strerror(errno), ": ", "epoll_create"); - } - rdma_addrinfo* addrinfo; rdma_addrinfo hints = {}; hints.ai_port_space = RDMA_PS_TCP; @@ -206,7 +198,7 @@ Status GdrMemoryManager::Init() { ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; - init_attr.cap.max_recv_wr = 32; + init_attr.cap.max_recv_wr = 1024; init_attr.cap.max_send_wr = 1; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -239,14 +231,6 @@ Status GdrMemoryManager::Init() { "cannot set server to non-blocking mode"); } - epoll_event event = {}; - event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = listening_.get(); - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) { - return errors::Unavailable(strerror(errno), ": ", - "cannot add server to epoll"); - } - numa_node_ = TryToReadNumaNode(listening_->verbs->device); SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node, @@ -265,121 +249,114 @@ Status GdrMemoryManager::Init() 
{ ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); LOG(INFO) << "Instrumenting CPU allocator(s)"; -#if GOOGLE_CUDA for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { GPUProcessState::singleton()->AddCUDAHostAllocVisitor(numa_idx, alloc_visitor); GPUProcessState::singleton()->AddCUDAHostFreeVisitor(numa_idx, free_visitor); } + if (IsGDRAvailable()) { SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id, size_t num_bytes) { VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id; InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id)); }; - for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { - GPUProcessState::singleton()->AddGPUAllocVisitor(numa_idx, - cuda_alloc_visitor); - } - VLOG(1) << "Instrumenting GPU allocator(s) for all Numas"; + GPUProcessState::singleton()->AddGPUAllocVisitor(numa_node_, + cuda_alloc_visitor); + LOG(INFO) << "Instrumenting GPU allocator for NUMA " << numa_node_; } -#endif // GOOGLE_CUDA + return Status::OK(); } void GdrMemoryManager::Run() { stopped_ = false; while (!stopped_) { - epoll_event events[32]; - int ret = epoll_wait(epfd_, events, 32, 1); - if (ret == -1) { - LOG(ERROR) << "epoll_wait: " << strerror(errno); - return; - } - for (int i = 0; i < ret; i++) { - rdma_cm_id* id = static_cast(events[i].data.ptr); - if (id == listening_.get()) { - // Accept incoming connections - if (!rdma_get_request(listening_.get(), &id)) { - if (!rdma_accept(id, nullptr)) { - LOG(INFO) << "Accepted new RDMA connection"; - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - EndpointDeleter(id); - continue; - } - for (int i = 0; i < 32; i++) { - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; - EndpointDeleter(id); - continue; - } - } - int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0); - if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags 
| O_NONBLOCK)) { - LOG(ERROR) << strerror(errno) - << ": cannot set server_client to non-blocking mode"; - EndpointDeleter(id); - continue; - } - epoll_event event = {}; - event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = id; - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd, - &event)) { - LOG(ERROR) << strerror(errno) - << ": cannot add server client to epoll"; - EndpointDeleter(id); - continue; - } - server_clients_.push_back({id, EndpointDeleter}); + rdma_cm_id* id = nullptr; + // Accept incoming connections + if (!rdma_get_request(listening_.get(), &id)) { + if (!rdma_accept(id, nullptr)) { + LOG(INFO) << "Accepted new RDMA connection"; + for (int i = 0; i < 1024; i++) { + if (rdma_post_recvv(id, nullptr, nullptr, 0)) { + LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; + EndpointDeleter(id); + continue; } } - } else { - // Polling work completions - ibv_cq* cq; - void* context; - if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) { - ibv_ack_cq_events(id->recv_cq, 1); - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - continue; + server_clients_.push_back({id, EndpointDeleter}); + } + } + // Polling server side work completions + for (const auto& client : server_clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client->recv_cq, 32, wc); + if (ret < 0) { + LOG(ERROR) << "ibv_poll_cq failed"; + continue; + } + for (int i = 0; i < ret; i++) { + if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { + LOG(ERROR) << "Received unknown operation " << wc[i].opcode; + } + if (wc[i].status != 0) { + LOG(ERROR) << ibv_wc_status_str(wc[i].status); + } + TensorKey tensor_key = ntohl(wc[i].imm_data); + + if (rdma_post_recvv(client.get(), nullptr, nullptr, 0)) { + perror("rdma_post_recvv"); + LOG(ERROR) << "rdma_post_recvv failed"; + } + + mutex_lock l(buf_mu_); + auto iter = tensor_buffers_.find(tensor_key); + if (iter == std::end(tensor_buffers_)) { + LOG(ERROR) << "Cannot find 
tensor buffer for tensor key " + << tensor_key; + } else { + const TensorBuffer* buffer = iter->second; + buffer->Unref(); + tensor_buffers_.erase(iter); + } + } + } + // Polling client side work completions + if (client_mu_.try_lock()) { + for (const auto& client : clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client.second->send_cq, 32, wc); + for (int i = 0; i < ret; i++) { + Status s; + if (wc[i].status) { + s = errors::Unavailable(ibv_wc_status_str(wc[i].status)); + } else { + s = Status::OK(); } - ibv_wc wc[32]; - int ret = ibv_poll_cq(id->recv_cq, 32, wc); - if (ret < 0) { - LOG(ERROR) << "ibv_poll_cq failed"; - continue; + TensorKey key = wc[i].wr_id; + + ibv_send_wr wr = {}; + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.imm_data = htonl(key); + ibv_send_wr* bad_wr; + if (ibv_post_send(client.second->qp, &wr, &bad_wr)) { + LOG(ERROR) << strerror(errno) + << ": ibv_post_send failed for tensor_key " << key; } - for (int i = 0; i < ret; i++) { - if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { - LOG(ERROR) << "Received unknown operation " << wc[i].opcode; - } - if (wc[i].status != 0) { - LOG(ERROR) << ibv_wc_status_str(wc[i].status); - } - TensorKey tensor_key = ntohl(wc[i].imm_data); - { - mutex_lock l(server_mu_); - auto iter = tensor_buffers_.find(tensor_key); - if (iter == std::end(tensor_buffers_)) { - LOG(ERROR) << "Cannot find tensor buffer for tensor key " - << tensor_key; - } else { - const TensorBuffer* buffer = iter->second; - buffer->Unref(); - tensor_buffers_.erase(iter); - } - } - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - perror("rdma_post_recvv"); - LOG(ERROR) << "rdma_post_recvv failed"; - continue; - } + + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(key); + if (iter != std::end(tensor_callbacks_)) { + iter->second(s); + tensor_callbacks_.erase(iter); + } else { + LOG(WARNING) << "Cannot find client callback with tensor key " + << key; } } } + client_mu_.unlock(); } } } @@ -390,116 +367,58 @@ void 
GdrMemoryManager::TransportOptionsFromTensor( ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor, Device* device, DeviceContext* device_context, bool on_host, StatusCallback done) { - auto buffer = DMAHelper::buffer(&tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - if (length == 0) { - done(errors::Unavailable("Cannot register tensor buffer of size 0")); - return; - } + ibv_mr* mr = FindMemoryRegion(&tensor); + const TensorBuffer* buffer = DMAHelper::buffer(&tensor); - ibv_mr* mr = FindMemoryRegion(addr, length); - -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); - Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); - GPUUtil::CopyGPUTensorToCPU( - device, device_context, &tensor, host_copy, - [done, host_copy, mutable_transport_options, this](const Status& s) { - if (!s.ok()) { - done(s); - delete host_copy; - return; - } - auto buffer = DMAHelper::buffer(host_copy); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - delete host_copy; - return; - } - - buffer->Ref(); - TensorKey tensor_key = next_key_++; - { - mutex_lock l(server_mu_); - tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); - } - - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(*host_copy); - } - - RemoteMemoryRegion remote_mr; - remote_mr.set_host(host_); - remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); - remote_mr.set_rkey(mr->rkey); - remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); - mutable_transport_options->PackFrom(remote_mr); - - done(Status::OK()); - delete host_copy; - }); - return; - } -#endif + Tensor* copy = nullptr; if (mr == nullptr) { - Allocator* alloc = 
ProcessState::singleton()->GetCPUAllocator(numa_node_); - Tensor host_copy(alloc, tensor.dtype(), tensor.shape()); - - std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length); - VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); if (mr == nullptr) { done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; return; } - - buffer->Ref(); - } else { - buffer->Ref(); } TensorKey tensor_key = next_key_++; + buffer->Ref(); { - mutex_lock l(server_mu_); + mutex_lock l(buf_mu_); tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, tensor); - } else { - checksum = GPUUtil::Checksum(tensor); - } -#endif - } - RemoteMemoryRegion remote_mr; remote_mr.set_host(host_); remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); + remote_mr.set_addr(reinterpret_cast(buffer->data())); remote_mr.set_rkey(mr->rkey); remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); mutable_transport_options->PackFrom(remote_mr); - done(Status::OK()); + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyDeviceTensorToCPU(&tensor, "" /* tensor_name */, device, + copy, [done, copy](const Status& s) { + done(s); + delete copy; + }); + return; + } else if (copy) { + std::memcpy(buffer->data(), DMAHelper::buffer(&tensor)->data(), + 
buffer->size()); + done(Status::OK()); + delete copy; // OK to delete; we have reffed the underlying TensorBuffer + } else { + done(Status::OK()); + } } void GdrMemoryManager::TensorFromTransportOptions( @@ -512,42 +431,10 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } - auto buffer = DMAHelper::buffer(tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - Tensor host_copy; -#if GOOGLE_CUDA - if (mr == nullptr && !on_host) { - Allocator* alloc = - GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - mr = FindMemoryRegion(addr, length); - } -#endif // GOOGLE_CUDA - - if (mr == nullptr) { - Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; - } - } - - decltype(clients_)::iterator iter; - bool success; + rdma_cm_id* id = nullptr; { + decltype(clients_)::iterator iter; + bool success; mutex_lock l(client_mu_); std::tie(iter, success) = clients_.insert( std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()), @@ -560,93 +447,94 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } } - } - rdma_cm_id* id = iter->second.get(); - - uint64_t start = Env::Default()->NowMicros(); - - if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0, - remote_mr.addr(), remote_mr.rkey())) { - done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); - return; + id = iter->second.get(); } - ibv_send_wr wr = {}; - wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr.imm_data = 
htonl(remote_mr.tensor_key()); - wr.send_flags = IBV_SEND_SIGNALED; - ibv_send_wr* bad_wr; - if (ibv_post_send(id->qp, &wr, &bad_wr)) { - done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed")); - return; - } + ibv_mr* mr = FindMemoryRegion(tensor); + const TensorBuffer* buffer = DMAHelper::buffer(tensor); - ibv_wc wc = {}; - int ret; - while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0) - ; - if (ret < 0 || wc.status) { - done(errors::Unavailable(ibv_wc_status_str(wc.status))); - return; - } + const Tensor* copy = nullptr; -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host && - host_copy.NumElements() > 0) { - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(host_copy); - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); + if (mr == nullptr) { + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor->dtype(), tensor->shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; + return; } - Tensor* ref = new Tensor; - std::swap(host_copy, *ref); - GPUUtil::CopyCPUTensorToGPU( - ref, device_context, device, tensor, - [ref, done, buffer, remote_mr, start](const Status& s) { - if (!s.ok()) { - done(s); - delete ref; - return; - } - uint64_t end = Env::Default()->NowMicros(); - - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) - << " micros"; - done(Status::OK()); - delete ref; - }); - return; } -#endif // GOOGLE_CUDA - if ((on_host || !device->tensorflow_gpu_device_info()) && - host_copy.NumElements() > 0) { - 
std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length); - VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; - } + uint64_t start = Env::Default()->NowMicros(); - uint64_t end = Env::Default()->NowMicros(); + TensorKey tensor_key = remote_mr.tensor_key(); - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) << " micros"; + StatusCallback callback = [done, copy, device, device_context, on_host, + tensor, start, tensor_key](const Status& s) { + if (!s.ok()) { + done(s); + if (copy) { + delete copy; + } + return; + } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, *tensor); + VLOG(2) << "RDMA of tensor " << tensor_key << " of size " + << DMAHelper::buffer(tensor)->size() << " took " + << (Env::Default()->NowMicros() - start) << " micros"; + + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyCPUTensorToDevice(copy, device, tensor, + [done, copy](const Status& s) { + done(s); + delete copy; + }); + } else if (copy) { + std::memcpy(DMAHelper::buffer(tensor)->data(), + DMAHelper::buffer(copy)->data(), + DMAHelper::buffer(copy)->size()); + done(s); + delete copy; } else { - checksum = GPUUtil::Checksum(*tensor); + done(s); + } + }; + + { + mutex_lock l(callback_mu_); + if (tensor_callbacks_.find(tensor_key) == std::end(tensor_callbacks_)) { + tensor_callbacks_.insert(std::make_pair(tensor_key, std::move(callback))); + } else { + done(errors::Unavailable("Received duplicated tensor key")); + if (copy) { + delete copy; + } + return; + } + } + + if (rdma_post_read(id, reinterpret_cast(tensor_key), buffer->data(), + buffer->size(), mr, IBV_SEND_SIGNALED, remote_mr.addr(), + remote_mr.rkey())) { + done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read 
failed")); + { + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(tensor_key); + if (iter != std::end(tensor_callbacks_)) { + tensor_callbacks_.erase(iter); + } + } + if (copy) { + delete copy; } - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); -#endif } - done(Status::OK()); } Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, @@ -663,7 +551,7 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; init_attr.cap.max_recv_wr = 1; - init_attr.cap.max_send_wr = 32; + init_attr.cap.max_send_wr = 1024; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -687,8 +575,8 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, return Status::OK(); } -ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) { - if (length == 0) return nullptr; +ibv_mr* GdrMemoryManager::FindMemoryRegion(const Tensor* tensor) { + const void* addr = DMAHelper::buffer(tensor)->data(); mutex_lock l(alloc_mu_); auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); if (iter == std::end(mrs_) || iter->get()->addr > addr) { diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index fbccbead03fc0d641db40ede661bf3677d44c45d..1124dff741309d8fd04954e70c5ebaaf164b940a 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -58,11 +58,9 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs); StatusCallback cb = [this, recv_done](const Status& s) { bool dma_ok = resp_.metadata().has_transport_options(); - if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) { + if (s.ok() && tensor().TotalBytes() > 1024 && (!is_dead()) && dma_ok) { auto transport_options = 
resp_.metadata().transport_options(); - const bool on_host = - (dst_device_->tensorflow_gpu_device_info() == nullptr) || - recv_args_.alloc_attrs.on_host(); + const bool on_host = recv_args_.alloc_attrs.on_host(); remote_memory_manager_->TensorFromTransportOptions( const_cast(&tensor()), transport_options, dst_device_, recv_args_.device_context, on_host, @@ -70,9 +68,6 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { if (!s.ok()) { mutex_lock l(mu_); status_.Update(s); - LOG(ERROR) << "Cannot find pinned memory region from allocator " - << dst_device_->GetAllocator(recv_args_.alloc_attrs) - ->Name(); } recv_done(); }); @@ -172,8 +167,11 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // RendezvousMgr already aborted, shouldn't send RPC call any more if (!call->status().ok()) { - done(call->status(), Args(), Args(), Tensor(), false); + // NOTE: `*session()` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. session()->worker_cache->ReleaseWorker(src_worker, rwi); + done(call->status(), Args(), Args(), Tensor(), false); delete call; return; } @@ -186,8 +184,11 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // If StartAbort was called prior to DeregisterCall, then the // current status should be bad. Status s = call->status(); - done(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); + // NOTE: `*session()` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. 
session()->worker_cache->ReleaseWorker(src_worker, rwi); + done(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); delete call; Unref(); }); diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index b3f48ec1dd9c75055f4e1ea76eb203b6ccf94718..c39cc0f9bcecc26aedfaf9707113210acf670244 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/contrib/gdr/gdr_server_lib.h" #include "grpc/support/alloc.h" +#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h" #include "tensorflow/contrib/gdr/gdr_memory_manager.h" #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h" #include "tensorflow/contrib/gdr/gdr_worker.h" - -#include "grpc/support/alloc.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" namespace tensorflow { @@ -57,10 +59,34 @@ Status GdrServer::Init() { return std::unique_ptr( new GdrWorker(env, config, remote_memory_manager_.get())); }; - + CollectiveMgrCreationFunction collective_mgr_func = + [this](const ConfigProto& config, const WorkerEnv* env, + WorkerCacheInterface* worker_cache) { + string unused; + string default_worker_name; + DeviceNameUtils::SplitDeviceName( + env->device_mgr->ListDevices()[0]->name(), &default_worker_name, + &unused); + + std::unique_ptr dev_resolver( + new DeviceResolverDistributed(env->device_mgr, worker_cache, + default_worker_name)); + std::unique_ptr param_resolver( + new CollectiveParamResolverDistributed( + config, env->device_mgr, dev_resolver.get(), worker_cache, + default_worker_name)); + return new GdrCollectiveExecutorMgr( + config, env->device_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, default_worker_name, + 
remote_memory_manager_.get()); + }; TF_RETURN_IF_ERROR(remote_memory_manager_->Init()); - return GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func); + GrpcServerOptions opts; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + opts.collective_mgr_func = collective_mgr_func; + opts.worker_func = worker_func; + return GrpcServer::Init(opts); } Status GdrServer::Start() { @@ -74,9 +100,8 @@ Status GdrServer::Start() { } Status GdrServer::Stop() { - TF_RETURN_IF_ERROR(GrpcServer::Stop()); remote_memory_manager_->Stop(); - return Status::OK(); + return GrpcServer::Stop(); } Status GdrServer::Join() { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 867cb83f42034c8e9061e333ea671457745f92c3..1204b8ca501a8f99ea6abd6c047ab2d91350bae1 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -15,12 +15,10 @@ limitations under the License. #include "tensorflow/contrib/gdr/gdr_worker.h" +#include "tensorflow/core/common_runtime/buf_rendezvous.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" @@ -32,6 +30,7 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_session.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -43,13 +42,13 @@ GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config, RemoteMemoryManager* remote_memory_manager) : GrpcWorker(worker_env, config), remote_memory_manager_(remote_memory_manager), - recv_tensor_recent_request_ids_(100000) {} + recent_request_ids_(100000) {} void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { - Status s = recv_tensor_recent_request_ids_.TrackUnique( + Status s = recent_request_ids_.TrackUnique( request->request_id(), "RecvTensor (GdrWorker)", *request); if (!s.ok()) { done(s); @@ -78,7 +77,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const bool dma_ok = request->dma_ok(); env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, - [this, opts, response, done, src_dev, dma_ok]( + [this, opts, response, done, src_dev, request, dma_ok]( const Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args&, const Tensor& val, const bool is_dead) { opts->ClearCancelCallback(); @@ -89,10 +88,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, // 3) the tensor has the on_host allocation attribute, // i.e. it's in CPU RAM *independent of its assigned // device type*. - const bool on_host = - (src_dev->tensorflow_gpu_device_info() == nullptr) || - send_args.alloc_attrs.on_host(); - if (val.TotalBytes() > 0 && (!is_dead) && + const bool on_host = send_args.alloc_attrs.on_host(); + if (val.TotalBytes() > 1024 && (!is_dead) && DMAHelper::CanUseDMA(&val) && dma_ok) { // DMA cases. 
RecvTensorResponse* proto = new RecvTensorResponse; @@ -117,8 +114,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, } else { // Non-DMA cases. if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { -#if GOOGLE_CUDA - const DeviceContext* send_dev_context = send_args.device_context; + DeviceContext* send_dev_context = send_args.device_context; AllocatorAttributes alloc_attrs; alloc_attrs.set_gpu_compatible(true); alloc_attrs.set_on_host(true); @@ -127,7 +123,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, CHECK(send_dev_context) << "send dev name: " << src_dev->name() << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - // "val" is on a GPU. Uses GPUUtil to fill the response proto. + // "val" is on an accelerator device. Uses the device_context to + // fill the copy on host. StatusCallback copy_ready = [response, done, copy, is_dead](const Status& s) { // The value is now ready to be returned on the wire. @@ -136,11 +133,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, delete copy; }; - GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy, - copy_ready); -#else - done(errors::Internal("No GPU device in process")); -#endif // GOOGLE_CUDA + send_dev_context->CopyDeviceTensorToCPU( + &val, request->rendezvous_key(), src_dev, copy, copy_ready); } else { grpc::EncodeTensorToByteBuffer(is_dead, val, response); done(Status::OK()); @@ -153,4 +147,41 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, }); } +void GdrWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) { + // This is an RDMA enabled implementation augmenting grpc. 
+ Status s = recent_request_ids_.TrackUnique(request->request_id(), + "RecvBuf (GdrWorker)", *request); + if (!s.ok()) { + done(s); + return; + } + CollectiveExecutor::Handle ce_handle( + env_->collective_executor_mgr->FindOrCreate(request->step_id()), true); + CollectiveRemoteAccess* rma = ce_handle.get()->remote_access(); + rma->buf_rendezvous()->ConsumeBuf( + request->buf_rendezvous_key(), + [this, request, response, done](const Status& status, + BufRendezvous::Hook* hook) { + Status s = status; + if (s.ok()) { + if (!DMAHelper::CanUseDMA(hook->prod_value)) { + s = errors::Internal("Tensor value for key ", + request->buf_rendezvous_key(), + " is not of a type supported by RecvBuf"); + } + } + if (s.ok()) { + remote_memory_manager_->TransportOptionsFromTensor( + response->mutable_transport_options(), *hook->prod_value, + hook->prod_dev, hook->prod_ctx, hook->prod_attr.on_host(), + [this, response, done, hook](const Status& s) { + response->set_send_start_micros(env_->env->NowMicros()); + done(s); + BufRendezvous::DoneWithHook(hook); + }); + } + }); +} + } // namespace tensorflow diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index 39f11e6bde5a1ca7ae91ead02279d22d70af027b..9a85cfd4263ad86f6579eedce95969c2829ff62c 100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -38,9 +38,13 @@ class GdrWorker : public GrpcWorker { ::grpc::ByteBuffer* response, StatusCallback done) override; + virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, + StatusCallback done) override; + private: RemoteMemoryManager* remote_memory_manager_; // Not owned - RecentRequestIds recv_tensor_recent_request_ids_; + RecentRequestIds recent_request_ids_; }; } // namespace tensorflow diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 
e79ccd8da1f8952758ae322d3a92dec34910a9db..5b37239665d46db38fc249e9004d2200abb3d610 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -22,7 +22,6 @@ from __future__ import print_function from copy import deepcopy from functools import partial from six import iteritems -from six import iterkeys from six import string_types from six import StringIO from tensorflow.contrib.graph_editor import reroute @@ -735,9 +734,8 @@ def graph_replace(target_ts, replacement_ts, dst_scope="", # control dependencies. graph = util.get_unique_graph(flatten_target_ts, check_types=(tf_ops.Tensor)) control_ios = util.ControlOutputs(graph) - ops = select.get_walks_intersection_ops(list(iterkeys(replacement_ts)), - flatten_target_ts, - control_ios=control_ios) + ops = select.get_walks_intersection_ops( + list(replacement_ts), flatten_target_ts, control_ios=control_ios) if not ops: raise ValueError("Targets and replacements are not connected!") diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index 5c5599858ee6879a5703d65658bf4bbd881c7e72..71eac729a8a81c2f59f9ed5d7f42fb7b1c3e1b5c 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -23,11 +23,16 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class SequenceFileDataset(dataset_ops.DatasetSource): """A Sequence File Dataset that reads the sequence file.""" + @deprecation.deprecated( + None, + "tf.contrib.hadoop will be removed in 2.0, the support for Apache Hadoop " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, filenames): """Create a `SequenceFileDataset`. 
@@ -50,13 +55,11 @@ class SequenceFileDataset(dataset_ops.DatasetSource): Args: filenames: A `tf.string` tensor containing one or more filenames. """ - super(SequenceFileDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - - def _as_variant_tensor(self): - return gen_dataset_ops.sequence_file_dataset( + variant_tensor = gen_dataset_ops.sequence_file_dataset( self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access + super(SequenceFileDataset, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 0081fb61770075a2c36e92f65e01126f657edeb4..92016e6a83975a9b15a39a15125e0eabc111912e 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -16,9 +16,31 @@ tf_cc_binary( srcs = ["hvx_ops_support_checker_main.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:candidate_sampling_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core:io_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:logging_ops_op_lib", + "//tensorflow/core:lookup_ops_op_lib", + "//tensorflow/core:manip_ops_op_lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", + "//tensorflow/core:parsing_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:random_ops_op_lib", + "//tensorflow/core:remote_fused_graph_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:sparse_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", + 
"//tensorflow/core:string_ops_op_lib", + "//tensorflow/core:training_ops_op_lib", + "//tensorflow/core:user_ops_op_lib", "//tensorflow/core/kernels:remote_fused_graph_execute_utils", "//tensorflow/core/kernels/hexagon:graph_transferer", "//tensorflow/tools/graph_transforms:file_utils", diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index 5a8c650fb927be0c835aaceffc516c048195c7bf..c1f6cac4942436d32f9867d4b5557c6b9e376c69 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -30,7 +30,8 @@ system based on Apache Ignite. ## Features -Ignite Dataset provides features that that you can use in a wide range of cases. The most important and interesting features are described below. +Ignite Dataset provides features that you can use in a wide range of cases. The +most important and interesting features are described below. ### Distributed In-Memory Datasource [Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that provides fast data access. It allows you to avoid limitations of hard drive and store and operate with as much data as you need in distributed cluster. You can utilize @@ -97,6 +98,7 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset +>>> tf.enable_eager_execution() >>> >>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) >>> @@ -116,7 +118,15 @@ Using this ability we can calculate gradients on the nodes the data is stored on Apache Ignite uses horizontal partitioning to store data in distributed cluster. When we create Apache Ignite cache (or table in terms of SQL), we can specify the number of partitions the data will be partitioned on. 
For example, if an Apache Ignite cluster consists of 10 machines and we create cache with 10 partitions, then every machine will maintain approximately one data partition. -Ignite Dataset allows using these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting correstondent environment variables for worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker so that one worker handles one partition and, at the same time, transparently work with single dataset. +Ignite Dataset allows using these two aspects of distributed neural network +training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a +computation graph operation that can be performed on a remote worker. The remote +worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) +by setting corresponding environment variables for worker process (such as +`IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using +this overriding approach, we can assign a specific partition to every worker so +that one worker handles one partition and, at the same time, transparently work +with a single dataset. ```python >>> import tensorflow as tf @@ -149,23 +159,31 @@ system called [IGFS](https://ignite.apache.org/features/igfs.html). IGFS delivers a similar functionality to Hadoop HDFS, but only in-memory. In fact, in addition to its own APIs, IGFS implements Hadoop FileSystem API and can be transparently plugged into Hadoop or Spark deployments. This contrib package -contains an integration between IGFS and TensorFlow. 
The integration is based -on [custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) -from TensorFlow side and +contains an integration between IGFS and TensorFlow. The integration is based on +[custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) from +TensorFlow side and [IGFS Native API](https://ignite.apache.org/features/igfs.html) from Apache -Ignite side. It has numerous uses, for example: * Checkpoints of state can be -saved to IGFS for reliability and fault-tolerance. * Training processes -communicate with TensorBoard by writing event files to a directory, which -TensorBoard watches. IGFS allows this communication to work even when -TensorBoard runs in a different process or machine. +Ignite side. It has numerous uses, for example: + +* Checkpoints of state can be saved to IGFS for reliability and + fault-tolerance. +* Training processes communicate with TensorBoard by writing event files to a + directory, which TensorBoard watches. IGFS allows this communication to work + even when TensorBoard runs in a different process or machine. ### SSL Connection -Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. +Apache Ignite allows to protect data transfer channels by +[SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and +authentication. Ignite Dataset supports both SSL connection with and without +authentication. For more information, please refer to the +[Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) +documentation. 
```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset +>>> tf.enable_eager_execution() >>> >>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", @@ -186,7 +204,7 @@ Following examples will help you to easily start working with this module. The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded -[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with +[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interact with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine: @@ -197,13 +215,13 @@ docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist After that you will be able to work with it following way: -![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") +![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist-2.png "Ignite Dataset Mnist") ### IGFS The simplest way to try IGFS with TensorFlow is to run [Docker](https://www.docker.com/) container with Apache Ignite and enabled IGFS -and then interruct with it using TensorFlow +and then interact with it using TensorFlow [tf.gfile](https://www.tensorflow.org/api_docs/python/tf/gfile). Such container is available on Docker Hub: [dmitrievanthony/ignite-with-igfs](https://hub.docker.com/r/dmitrievanthony/ignite-with-igfs/). 
diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index e4762c91b193f9c5e32fa2642e702e61e8e5e57f..3ffceef8070e0fc3b3cebae2522f89fe98ce4413 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation @six.add_metaclass(abc.ABCMeta) @@ -699,6 +700,10 @@ class IgniteDataset(dataset_ops.DatasetSource): Ignite Binary Client Protocol. """ + @deprecation.deprecated( + None, + "tf.contrib.ignite will be removed in 2.0, the support for Apache Ignite " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, cache_name, host="localhost", @@ -730,8 +735,6 @@ class IgniteDataset(dataset_ops.DatasetSource): cert_password: Password to be used if the private key is encrypted and a password is necessary. 
""" - super(IgniteDataset, self).__init__() - with IgniteClient(host, port, username, password, certfile, keyfile, cert_password) as client: client.handshake() @@ -755,6 +758,8 @@ class IgniteDataset(dataset_ops.DatasetSource): self.cache_type.to_output_types(), self.cache_type.to_output_shapes(), self.cache_type.to_output_classes()) + super(IgniteDataset, self).__init__(self._as_variant_tensor()) + def _as_variant_tensor(self): return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, self.local, self.part, self.page_size, diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py index ff5d4c458c859fd8e5e3ae65ee41a454d55d6538..89b74fbfdc38c9f42795d5c778889210baf6387f 100644 --- a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -19,9 +19,9 @@ from __future__ import print_function import os +from tensorflow import compat from tensorflow.contrib.ignite import IgniteDataset from tensorflow.python.client import session -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -66,7 +66,7 @@ class IgniteDatasetTest(test.TestCase): self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) - it = dataset_ops.make_one_shot_iterator(dataset) + it = compat.v1.data.make_one_shot_iterator(dataset) ne = it.get_next() with session.Session() as sess: diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh old mode 100644 new mode 100755 diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 
2b86331099ccae03664462987ee0c141d766c10f..5591c3b0cc8c8bf196bb4821c018cbf155cba4ce 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -23,12 +23,17 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class KafkaDataset(dataset_ops.DatasetSource): """A Kafka Dataset that consumes the message. """ + @deprecation.deprecated( + None, + "tf.contrib.kafka will be removed in 2.0, the support for Apache Kafka " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, topics, servers="localhost", @@ -47,7 +52,6 @@ class KafkaDataset(dataset_ops.DatasetSource): timeout: The timeout value for the Kafka Consumer to wait (in millisecond). """ - super(KafkaDataset, self).__init__() self._topics = ops.convert_to_tensor( topics, dtype=dtypes.string, name="topics") self._servers = ops.convert_to_tensor( @@ -58,6 +62,8 @@ class KafkaDataset(dataset_ops.DatasetSource): self._timeout = ops.convert_to_tensor( timeout, dtype=dtypes.int64, name="timeout") + super(KafkaDataset, self).__init__(self._as_variant_tensor()) + def _as_variant_tensor(self): return gen_dataset_ops.kafka_dataset(self._topics, self._servers, self._group, self._eof, self._timeout) diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index c4476a7bbd5056fa898468a46031bf3d8b1e44cf..b12832d2e2a3cccb4948d9e3bf3d226030121ac2 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -22,7 +22,7 @@ from __future__ import print_function from tensorflow.python.keras.losses import binary_crossentropy from tensorflow.python.keras.losses import categorical_crossentropy from 
tensorflow.python.keras.losses import categorical_hinge -from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import cosine_similarity from tensorflow.python.keras.losses import hinge from tensorflow.python.keras.losses import kullback_leibler_divergence from tensorflow.python.keras.losses import logcosh diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 7317fdb52c5b79e787a49d71be49f5261d6b1fff..095b5d798df9ac9038fa1088cdd402dff304e87e 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -23,7 +23,7 @@ from tensorflow.python.keras.metrics import binary_accuracy from tensorflow.python.keras.metrics import binary_crossentropy from tensorflow.python.keras.metrics import categorical_accuracy from tensorflow.python.keras.metrics import categorical_crossentropy -from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import cosine_similarity from tensorflow.python.keras.metrics import hinge from tensorflow.python.keras.metrics import kullback_leibler_divergence from tensorflow.python.keras.metrics import mean_absolute_error diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 4ef0a66a52429233c6e6f70667a451466493629c..294a7d69a704b3c06ab9e30489af116929ab6c2a 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -34,7 +34,7 @@ def sparse_multiclass_hinge_loss( scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS): - """Adds Ops for computing the multiclass hinge loss. + r"""Adds Ops for computing the multiclass hinge loss. 
The implementation is based on the following paper: On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 42b91d031375b8edb7e4f364ac91ffb74ef1f54b..19daffea6c7e4486499388314d0aaaa611e94218 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,3 +1,3 @@ # K-FAC: Kronecker-Factored Approximate Curvature -## KFAC moved to third_party/tensorflow_kfac. +## KFAC moved to https://github.com/tensorflow/kfac. diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 20395395281768ac429984a1e3552cfd187527a2..9479afb180df7bb4a08d6aafa4fc3bf63489d9f3 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class KinesisDataset(dataset_ops.DatasetSource): @@ -50,6 +51,10 @@ class KinesisDataset(dataset_ops.DatasetSource): is returned immediately instead. """ + @deprecation.deprecated( + None, + "tf.contrib.kinesis will be removed in 2.0, the support for Kinesis " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, stream, shard="", @@ -66,7 +71,6 @@ class KinesisDataset(dataset_ops.DatasetSource): interval: The interval for the Kinesis Client to wait before it tries to get records again (in millisecond). 
""" - super(KinesisDataset, self).__init__() self._stream = ops.convert_to_tensor( stream, dtype=dtypes.string, name="stream") self._shard = ops.convert_to_tensor( @@ -75,6 +79,7 @@ class KinesisDataset(dataset_ops.DatasetSource): read_indefinitely, dtype=dtypes.bool, name="read_indefinitely") self._interval = ops.convert_to_tensor( interval, dtype=dtypes.int64, name="interval") + super(KinesisDataset, self).__init__(self._as_variant_tensor()) def _as_variant_tensor(self): return gen_dataset_ops.kinesis_dataset( diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 588f15b867c1fedbadd5a5d945d870a356549468..7e19ae7c13df421ec5bb9cb0e07dff0d00fb9548 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -155,7 +155,7 @@ py_library( ":core", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:functional_ops", + "//tensorflow/python:map_fn", "//tensorflow/python:math_ops", "//tensorflow/python:numerics", "//tensorflow/python:random_ops", diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 2ede5daee74223e812cc29e9708b1989b698fb4e..a65f045cc886f4d4f351423858d92412baa3a622 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import map_fn as map_fn_lib from tensorflow.python.ops import math_ops from tensorflow.python.ops import numerics from tensorflow.python.ops import random_ops @@ -629,7 +630,7 @@ def map_fn(fn, labeled_tensor, name=None): # TODO(ericmc): Fix this upstream. 
if labeled_tensor.dtype == dtypes.string: - # We must construct the full graph here, because functional_ops.map_fn + # We must construct the full graph here, because map_fn_lib.map_fn # doesn't work for string-valued tensors. # Constructing the full graph may be slow. map_lts = [fn(t) for t in unpack_lts] @@ -652,7 +653,7 @@ def map_fn(fn, labeled_tensor, name=None): tensor_lt = core.LabeledTensor(tensor, original_axes) return fn(tensor_lt).tensor - map_op = functional_ops.map_fn( + map_op = map_fn_lib.map_fn( tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype) map_lt = core.LabeledTensor(map_op, final_axes) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 9ca6f8df5dbe3c236c4cd85095176ce69ad9deaa..69d5496f8aebb9b89c5d79f80a1a439f556093d7 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -81,6 +81,7 @@ tf_custom_op_py_library( visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 7e6eafaa0d6f60cfc28a4c422abac0b6d5a991fb..00e41026d0038409ace178e6affd2c1cdc812122 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -1757,7 +1757,7 @@ class WeightedSumTest(test.TestCase): logits_core = fc_core.linear_model(features, [movies]) with self.cached_session() as sess: - variables_lib.initialize_all_variables().run() + variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 
403b522ce45ac6ad98a321378626b87aaa7738aa..9d9524e4e4b995d795b7c71b5bd083d11c60d5ce 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2308,7 +2308,7 @@ def layer_norm(inputs, initializer=init_ops.ones_initializer(), collections=gamma_collections, trainable=trainable) - # Calculate the moments on the last axis (layer activations). + # By default, compute the moments across all the dimensions except the one with index 0. norm_axes = list(range(begin_norm_axis, inputs_rank)) mean, variance = nn.moments(inputs, norm_axes, keep_dims=True) # Compute layer normalization using the batch_normalization function. diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index d791418c9d0f887058ceb535092fa8122da1aa75..1c0088186c030437454c0f764decab9e5a276adc 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1356,7 +1356,7 @@ class DropoutTest(test.TestCase): with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) - self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') + self.assertEqual(output.op.name, 'Dropout/dropout_1/mul_1') output.get_shape().assert_is_compatible_with( ops.convert_to_tensor(images).get_shape()) diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index 11033a2e9cb646c2e7cd2f45de1f751d88c6921a..76b03ff514821d3459f84c5f46a64d1134e0d4de 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -186,7 +186,7 @@ def group_norm(inputs, Args: inputs: A Tensor with at least 2 dimensions one which is channels. All - shape dimensions must be fully defined. + shape dimensions except for batch must be fully defined. groups: Integer. 
Divide the channels into this number of groups over which normalization statistics are computed. This number must be commensurate with the number of channels in `inputs`. @@ -249,13 +249,21 @@ def group_norm(inputs, """ # TODO(shlens): Support partially defined shapes for the inputs. inputs = ops.convert_to_tensor(inputs) - original_shape = inputs.shape if inputs.shape.ndims is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) if channels_axis > (inputs.shape.ndims - 1): raise ValueError('Axis is out of bounds.') + # Use dynamic shape for not fully defined dimensions in the inputs. + dyanmic_shape = array_ops.shape(inputs) + input_shape_list = [] + for i, dim in enumerate(inputs.shape): + if dim.value is None: + input_shape_list.append(dyanmic_shape[i]) + else: + input_shape_list.append(dim) + # Standardize the channels_axis to be positive and identify # of channels. if channels_axis < 0: channels_axis = inputs.shape.ndims + channels_axis @@ -289,8 +297,8 @@ def group_norm(inputs, # Determine axes before channels. Some examples of common image formats: # 'NCHW': before = [N], after = [HW] # 'NHWC': before = [NHW], after = [] - axes_before_channels = inputs.shape.as_list()[:channels_axis] - axes_after_channels = inputs.shape.as_list()[channels_axis+1:] + axes_before_channels = input_shape_list[:channels_axis] + axes_after_channels = input_shape_list[channels_axis+1:] # Manually broadcast the parameters to conform to the number of groups. params_shape_broadcast = ([1] * len(axes_before_channels) + @@ -369,7 +377,7 @@ def group_norm(inputs, outputs = inputs * gain + offset # Collapse the groups into the channel dimension. 
- outputs = array_ops.reshape(outputs, original_shape) + outputs = array_ops.reshape(outputs, input_shape_list) if activation_fn is not None: outputs = activation_fn(outputs) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index c8d3c91b10dbe3b959e91182f9924b78352d370d..9a85084b239837ade87d8c778393ef8e885f5bdd 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -221,6 +221,15 @@ class GroupNormTest(test.TestCase): normalization.group_norm(inputs, channels_axis=-1, reduction_axes=[-3, -2]) + def testParamsShapeNotFullyDefinedBatchAxis(self): + height, width, groups = 3, 3, 4 + inputs = array_ops.placeholder(dtypes.float32, + shape=(None, height, width, 2*groups)) + output = normalization.group_norm(inputs, channels_axis=-1, + reduction_axes=[-3, -2], groups=groups) + self.assertListEqual([None, height, width, 2 * groups], + output.shape.as_list()) + def testCreateOp(self): height, width, groups = 3, 3, 4 images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 8a6b4f68a8b33d497ddb16614a7e3cdf32f2c422..5234869718b427d7e275b76ae12021a096241a56 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -399,7 +399,7 @@ def _mean_squared_loss(logits, target): target = array_ops.expand_dims(target, axis=1) logits.get_shape().assert_is_compatible_with(target.get_shape()) - return math_ops.square(logits - math_ops.to_float(target)) + return math_ops.squared_difference(logits, math_ops.to_float(target)) def _log_loss_with_two_classes(logits, target): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 
14065fcee51c014a1af227504eaaca1fa39941e1..4749371248ee89a033912132986d7f76c85dbaa6 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -357,9 +357,9 @@ py_test( py_test( name = "dnn_linear_combined_test", - size = "large", + size = "medium", srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], - shard_count = 4, + shard_count = 8, srcs_version = "PY2AND3", tags = ["no_oss"], # flaky b/70524820 deps = [ diff --git a/tensorflow/contrib/learn/README.md b/tensorflow/contrib/learn/README.md index b0bff915a993c9a01e2e6d9ef9f71c14d2f29a73..b2d3a6273abba7e3a893f30bbdd4f8b2662bd54a 100644 --- a/tensorflow/contrib/learn/README.md +++ b/tensorflow/contrib/learn/README.md @@ -111,18 +111,17 @@ Some arguments are renamed, please refer to documentation. In addition: Switch to `tf.estimator.train_and_evaluate`. Some differences: -* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, - should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. -* Remove the `experiment_fn`. Instead, create the `Estimator`, - `train_spec` and `eval_spec`, then call `tf.estimator.train_and_evaluate` - directly. -* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement - for `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the - replacement for `tf.contrib.learn.make_export_strategy`. If you want to export - only at the end of training use `tf.estimator.FinalExporter`. -* If the `TF_CONFIG` environment variable is constructed manually, please read - the `train_and_evaluate` documentation for the new requirementds (in - particular, the chief node and evaluator node). +* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, + should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. +* Remove the `experiment_fn`. Instead, create the `Estimator`, `train_spec` + and `eval_spec`, then call `tf.estimator.train_and_evaluate` directly. 
+* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement for + `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the + replacement for `tf.contrib.learn.make_export_strategy`. If you want to + export only at the end of training use `tf.estimator.FinalExporter`. +* If the `TF_CONFIG` environment variable is constructed manually, please read + the `train_and_evaluate` documentation for the new requirements (in + particular, the chief node and evaluator node). ## Others Classes and Functions diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 28c4964527bb034c8c6b1642366c6c82c1a72201..c3e9e3af9427037a4e7be6b86417cd081c42ef67 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -37,8 +37,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell @@ -524,7 +524,7 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): def input_fn(): starts = random_ops.random_uniform( [batch_size], maxval=(2 * np.pi), seed=seed) - sin_curves = functional_ops.map_fn( + sin_curves = map_fn.map_fn( _sin_fn, (starts,), dtype=dtypes.float32) inputs = array_ops.expand_dims( array_ops.slice(sin_curves, [0, 0], [batch_size, sequence_length]), diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 
8a461a0bd7ba457fcf830769f23c6ca2860a2732..cbcae338a0a195da2aca1eea2e1b4c7eb8b0e35e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -1181,14 +1181,14 @@ class EstimatorTest(test.TestCase): ] self.assertItemsEqual([expected_vocab_file], assets) graph_ops = [x.name for x in graph.get_operations()] - self.assertTrue('input_example_tensor' in graph_ops) - self.assertTrue('ParseExample/ParseExample' in graph_ops) - self.assertTrue('linear/linear/feature/matmul' in graph_ops) + self.assertIn('input_example_tensor', graph_ops) + self.assertIn('ParseExample/ParseExample', graph_ops) + self.assertIn('linear/linear/feature/matmul', graph_ops) # Since there were no transforms, both save ops are still present. - self.assertTrue('save/SaveV2/tensor_names' in graph_ops) - self.assertTrue('save_1/SaveV2/tensor_names' in graph_ops) + self.assertIn('save/SaveV2/tensor_names', graph_ops) + self.assertIn('save_1/SaveV2/tensor_names', graph_ops) # Since there were no transforms, the hash table lookup is still there. - self.assertTrue('hash_table_Lookup' in graph_ops) + self.assertIn('hash_table_Lookup/LookupTableFindV2', graph_ops) # Restore, to validate that the export was well-formed. # tag_2, tag_3 was subjected to strip_unused_nodes. 
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index c1b97d8b49613ea49d9813954da3b7a63d3ba04c..4bb14a6e63b159fa4d09c9ef20947d4b125de657 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -567,7 +567,8 @@ def _mean_squared_loss(labels, logits, weights=None): if len(logits.get_shape()) == 1: logits = array_ops.expand_dims(logits, axis=1) logits.get_shape().assert_is_compatible_with(labels.get_shape()) - loss = math_ops.square(logits - math_ops.to_float(labels), name=name) + loss = math_ops.squared_difference( + logits, math_ops.to_float(labels), name=name) return _compute_weighted_loss(loss, weights) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 7c2d9bb0767cb979dae9c84b5342d129225677ed..a52d25acf402bdda46771e9146a40cfb71e99d53 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -62,8 +62,8 @@ def _assert_no_variables(test_case): def _assert_metrics(test_case, expected_loss, expected_eval_metrics, model_fn_ops): test_case.assertAlmostEqual(expected_loss, model_fn_ops.loss.eval(), places=4) - for k in six.iterkeys(expected_eval_metrics): - test_case.assertIn(k, six.iterkeys(model_fn_ops.eval_metric_ops)) + for k in expected_eval_metrics: + test_case.assertIn(k, model_fn_ops.eval_metric_ops) variables.initialize_local_variables().run() for key, expected_value in six.iteritems(expected_eval_metrics): value_tensor, update_tensor = model_fn_ops.eval_metric_ops[key] @@ -545,19 +545,19 @@ class MultiLabelHeadTest(test.TestCase): with session.Session(): self.assertListEqual( [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0]) - self.assertItemsEqual( - ["head_name"], 
six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.731059, 0.5, 0.5], [0.5, 0.5, 0.731059,]], @@ -850,18 +850,18 @@ class BinaryClassificationHeadTest(test.TestCase): with session.Session(): self.assertListEqual( [1, 1], list(model_fn_ops.predictions["classes"].eval())) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.LOGISTIC_REGRESSION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) predicted_classes = predictions_for_serving["classes"].eval().tolist() self.assertListEqual( [b"0", b"1"], predicted_classes[0]) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) def testBinaryClassificationInferMode_withWeightColumn(self): n_classes = 2 @@ -1349,18 +1349,18 @@ class MultiClassHeadTest(test.TestCase): self.assertAllEqual( [0, 2], model_fn_ops.predictions["classes"].eval()) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + 
list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.576117, 0.2119416, 0.2119416], [0.2119416, 0.2119416, 0.576117]], @@ -1401,18 +1401,18 @@ class MultiClassHeadTest(test.TestCase): self.assertAllEqual( [b"key0", b"key2"], model_fn_ops.predictions["classes"].eval()) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.576117, 0.2119416, 0.2119416], [0.2119416, 0.2119416, 0.576117]], diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py index 5e90d1fa20535de3b5e25bc7ff8c3862cea5514c..318046733bf75a6d661d26f478118c8e944afe15 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py +++ 
b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py @@ -174,7 +174,7 @@ class GeneratorIoTest(test.TestCase): return np.arange(32, 36) with self.cached_session(): - with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'): + with self.assertRaisesRegexp(TypeError, r'x\(\) must be generator'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() @@ -185,7 +185,7 @@ class GeneratorIoTest(test.TestCase): yield np.arange(32, 36) with self.cached_session(): - with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'): + with self.assertRaisesRegexp(TypeError, r'x\(\) must yield dict'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py index e7d091e18a8f186f89f5217442c24fb106c5cdab..af93e517f51ed33a8968982945ac1f65ec915ab1 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py @@ -36,10 +36,10 @@ def _create_parser(base_dir): # Modify the path object for RegEx match for Windows Paths if os.name == "nt": match = re.match( - "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$", + r"^" + compat.as_str_any(base_dir).replace("\\", "/") + r"/(\d+)$", compat.as_str_any(path.path).replace("\\", "/")) else: - match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$", + match = re.match(r"^" + compat.as_str_any(base_dir) + r"/(\d+)$", compat.as_str_any(path.path)) if not match: return None diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a28394964a12013c43d85701b5a0ab5c559afd62..8fda828e994bc2436eaba4475077020436703631 100644 --- 
a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation -# TODO(rohanj): This should subclass Checkpointable and implement +# TODO(rohanj): This should subclass Trackable and implement # _gather_saveables_for_checkpoint. class ShardedMutableDenseHashTable(object): """A sharded version of MutableDenseHashTable. diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index e52fb5ab1431e086f99b4033a6216636a83bad79..3d21fb68a1452c97f7eb85491fc850d9e846266a 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools - from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,7 +26,6 @@ from tensorflow.python.ops import lookup_ops # pylint: disable=unused-import from tensorflow.python.ops.lookup_ops import FastHashSpec from tensorflow.python.ops.lookup_ops import HasherSpec -from tensorflow.python.ops.lookup_ops import HashTable from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets from tensorflow.python.ops.lookup_ops import index_table_from_file from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file @@ -42,7 +39,6 @@ from tensorflow.python.ops.lookup_ops import TextFileIndex from tensorflow.python.ops.lookup_ops import TextFileInitializer from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer # pylint: enable=unused-import -from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow.python.util.deprecation import deprecated @@ -91,7 +87,7 @@ def 
index_table_from_tensor(mapping, The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -158,7 +154,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): will throw a FailedPreconditionError. The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. For example: @@ -202,7 +198,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -257,7 +253,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. For example: @@ -288,353 +284,52 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): return table.lookup(tensor) -class MutableHashTable(LookupInterface): - """A generic mutable hash table implementation. - - Data can be inserted by calling the insert method and removed by calling the - remove method. It does not support initialization via the init method. +class HashTable(InitializableLookupTableBase): + """A generic hash table implementation. 
Example usage: ```python - table = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string, - value_dtype=tf.int64, - default_value=-1) - sess.run(table.insert(keys, values)) - out = table.lookup(query_keys) + table = tf.HashTable( + tf.KeyValueTensorInitializer(keys, values), -1) + out = table.lookup(input_tensor) + table.init.run() print(out.eval()) ``` """ - def __init__(self, - key_dtype, - value_dtype, - default_value, - shared_name=None, - name="MutableHashTable", - checkpoint=True): - """Creates an empty `MutableHashTable` object. + def __init__(self, initializer, default_value, shared_name=None, name=None): + """Creates a non-initialized `HashTable` object. - Creates a table, the type of its keys and values are specified by key_dtype - and value_dtype, respectively. + Creates a table, the type of its keys and values are specified by the + initializer. + Before using the table you will have to initialize it. After initialization + the table will be immutable. Args: - key_dtype: the type of the key tensors. - value_dtype: the type of the value tensors. + initializer: The table initializer to use. See `HashTable` kernel for + supported key and value types. default_value: The value to use if a key is missing in the table. - shared_name: If non-empty, this table will be shared under - the given name across multiple sessions. + shared_name: If non-empty, this table will be shared under the given name + across multiple sessions. name: A name for the operation (optional). - checkpoint: if True, the contents of the table are saved to and restored - from checkpoints. If `shared_name` is empty for a checkpointed table, it - is shared using the table node name. Returns: - A `MutableHashTable` object. - - Raises: - ValueError: If checkpoint is True and no name was specified. + A `HashTable` object. 
""" - self._default_value = ops.convert_to_tensor(default_value, - dtype=value_dtype) - self._value_shape = self._default_value.get_shape() - self._checkpoint = checkpoint - self._key_dtype = key_dtype - self._value_dtype = value_dtype - self._name = name - - if context.executing_eagerly() and shared_name is None: - # TODO(allenl): This will leak memory due to kernel caching by the - # shared_name attribute value (but is better than the alternative of - # sharing everything by default when executing eagerly; hopefully creating - # tables in a loop is uncommon). - shared_name = "table_%d" % (ops.uid(),) + self._initializer = initializer + self._default_value = default_value self._shared_name = shared_name - super(MutableHashTable, self).__init__(key_dtype, value_dtype) - - self._resource_handle = self.create_resource() - if checkpoint: - saveable = MutableHashTable._Saveable(self, name) - if not context.executing_eagerly(): - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - - def create_resource(self): - # The table must be shared if checkpointing is requested for multi-worker - # training to work correctly. Use the node name if no shared_name has been - # explicitly specified. 
- use_node_name_sharing = self._checkpoint and self._shared_name is None - if self._default_value.get_shape().ndims == 0: - table_ref = gen_lookup_ops.mutable_hash_table_v2( - shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - key_dtype=self._key_dtype, - value_dtype=self._value_dtype, - name=self._name) - else: - table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2( - shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - key_dtype=self._key_dtype, - value_dtype=self._value_dtype, - value_shape=self._default_value.get_shape(), - name=self._name) - - if context.executing_eagerly(): - self._table_name = None - else: - self._table_name = table_ref.op.name.split("/")[-1] - return table_ref - - @property - def name(self): - return self._table_name - - def size(self, name=None): - """Compute the number of elements in this table. - - Args: - name: A name for the operation (optional). - - Returns: - A scalar tensor containing the number of elements in this table. - """ - with ops.name_scope(name, "%s_Size" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - return gen_lookup_ops.lookup_table_size_v2( - self.resource_handle, name=name) - - def remove(self, keys, name=None): - """Removes `keys` and its associated values from the table. - - If a key is not present in the table, it is silently ignored. - - Args: - keys: Keys to remove. Can be a tensor of any shape. Must match the table's - key type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." 
% - (self._key_dtype, keys.dtype)) - - with ops.name_scope( - name, "%s_lookup_table_remove" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_remove_v2( - self.resource_handle, keys, name=name) - - return op - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values. - - The `default_value` is used for keys not present in the table. - - Args: - keys: Keys to look up. Can be a tensor of any shape. Must match the - table's key_dtype. - name: A name for the operation (optional). - - Returns: - A tensor containing the values in the same shape as `keys` using the - table's value type. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - with ops.name_scope( - name, "%s_lookup_table_find" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - with ops.colocate_with(self.resource_handle): - values = gen_lookup_ops.lookup_table_find_v2( - self.resource_handle, keys, self._default_value, name=name) - return values - - def insert(self, keys, values, name=None): - """Associates `keys` with `values`. - - Args: - keys: Keys to insert. Can be a tensor of any shape. Must match the - table's key type. - values: Values to be associated with keys. Must be a tensor of the same - shape as `keys` and match the table's value type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` or `values` doesn't match the table data - types. 
- """ - with ops.name_scope(name, "%s_lookup_table_insert" % self.name, - [self.resource_handle, keys, values]) as name: - keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys") - values = ops.convert_to_tensor(values, self._value_dtype, name="values") - with ops.colocate_with(self.resource_handle): - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_insert_v2( - self.resource_handle, keys, values, name=name) - return op - - def export(self, name=None): - """Returns tensors of all keys and values in the table. - - Args: - name: A name for the operation (optional). - - Returns: - A pair of tensors with the first tensor containing all keys and the - second tensors containing all values in the table. - """ - with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self.resource_handle, self._key_dtype, self._value_dtype, name=name) - return exported_keys, exported_values - - def _gather_saveables_for_checkpoint(self): - """For object-based checkpointing.""" - return {"table": functools.partial(MutableHashTable._Saveable, table=self)} - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject implementation for MutableHashTable.""" - - def __init__(self, table, name): - tensors = table.export() - specs = [ - BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), - BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") - ] - # pylint: disable=protected-access - super(MutableHashTable._Saveable, self).__init__(table, specs, name) - - def restore(self, restored_tensors, restored_shapes): - del restored_shapes # unused - # pylint: disable=protected-access - with ops.colocate_with(self.op.resource_handle): - return gen_lookup_ops.lookup_table_import_v2( - self.op.resource_handle, restored_tensors[0], restored_tensors[1]) - - -class 
MutableDenseHashTable(LookupInterface): - """A generic mutable hash table implementation using tensors as backing store. - - Data can be inserted by calling the insert method and removed by calling the - remove method. It does not support initialization via the init method. - - It uses "open addressing" with quadratic reprobing to resolve collisions. - Compared to `MutableHashTable` the insert, remove and lookup operations in a - `MutableDenseHashTable` are typically faster, but memory usage can be higher. - However, `MutableDenseHashTable` does not require additional memory for - temporary tensors created during checkpointing and restore operations. - - Example usage: - - ```python - table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64, - value_dtype=tf.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - sess.run(table.insert(keys, values)) - out = table.lookup(query_keys) - print(out.eval()) - ``` - """ - - # TODO(andreasst): consider extracting common code with MutableHashTable into - # a common superclass. - def __init__(self, - key_dtype, - value_dtype, - default_value, - empty_key, - deleted_key, - initial_num_buckets=None, - shared_name=None, - name="MutableDenseHashTable", - checkpoint=True): - """Creates an empty `MutableDenseHashTable` object. - - Creates a table, the type of its keys and values are specified by key_dtype - and value_dtype, respectively. - - Args: - key_dtype: the type of the key tensors. - value_dtype: the type of the value tensors. - default_value: The value to use if a key is missing in the table. - empty_key: the key to use to represent empty buckets internally. Must not - be used in insert, remove or lookup operations. - initial_num_buckets: the initial number of buckets. - shared_name: If non-empty, this table will be shared under - the given name across multiple sessions. - name: A name for the operation (optional). 
- checkpoint: if True, the contents of the table are saved to and restored - from checkpoints. If `shared_name` is empty for a checkpointed table, it - is shared using the table node name. - deleted_key: the key to use to represent deleted buckets internally. Must - not be used in insert, remove or lookup operations and be different from - the empty_key. - - Returns: - A `MutableDenseHashTable` object. - - Raises: - ValueError: If checkpoint is True and no name was specified. - """ - self._default_value = ops.convert_to_tensor( - default_value, dtype=value_dtype, name="default_value") - self._key_dtype = key_dtype - self._value_dtype = value_dtype - self._initial_num_buckets = initial_num_buckets + self._name = name or "hash_table" + self._table_name = None + super(HashTable, self).__init__(default_value, initializer) self._value_shape = self._default_value.get_shape() - self._checkpoint = checkpoint - self._name = name - - self._empty_key = ops.convert_to_tensor( - empty_key, dtype=key_dtype, name="empty_key") - self._deleted_key = ops.convert_to_tensor( - deleted_key, dtype=key_dtype, name="deleted_key") - if context.executing_eagerly() and shared_name is None: - # TODO(allenl): This will leak memory due to kernel caching by the - # shared_name attribute value (but is better than the alternative of - # sharing everything by default when executing eagerly; hopefully creating - # tables in a loop is uncommon). - shared_name = "table_%d" % (ops.uid(),) - self._shared_name = shared_name - super(MutableDenseHashTable, self).__init__(key_dtype, value_dtype) - - self._resource_handle = self.create_resource() - if checkpoint: - saveable = MutableDenseHashTable._Saveable(self, name) - if not context.executing_eagerly(): - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) def create_resource(self): - # The table must be shared if checkpointing is requested for multi-worker - # training to work correctly. 
Use the node name if no shared_name has been - # explicitly specified. - use_node_name_sharing = self._checkpoint and self._shared_name is None - table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( - empty_key=self._empty_key, - deleted_key=self._deleted_key, + table_ref = gen_lookup_ops.hash_table_v2( shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - value_dtype=self._value_dtype, - value_shape=self._value_shape, - initial_num_buckets=self._initial_num_buckets, + key_dtype=self._initializer.key_dtype, + value_dtype=self._initializer.value_dtype, name=self._name) if context.executing_eagerly(): self._table_name = None @@ -646,103 +341,6 @@ class MutableDenseHashTable(LookupInterface): def name(self): return self._table_name - def size(self, name=None): - """Compute the number of elements in this table. - - Args: - name: A name for the operation (optional). - - Returns: - A scalar tensor containing the number of elements in this table. - """ - with ops.name_scope(name, "%s_Size" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - return gen_lookup_ops.lookup_table_size_v2( - self.resource_handle, name=name) - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values. - - The `default_value` is used for keys not present in the table. - - Args: - keys: Keys to look up. Can be a tensor of any shape. Must match the - table's key_dtype. - name: A name for the operation (optional). - - Returns: - A tensor containing the values in the same shape as `keys` using the - table's value type. - - Raises: - TypeError: when `keys` do not match the table data types. 
- """ - with ops.name_scope(name, "%s_lookup_table_find" % self.name, - [self.resource_handle, keys]) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - with ops.colocate_with(self.resource_handle): - values = gen_lookup_ops.lookup_table_find_v2( - self.resource_handle, keys, self._default_value, name=name) - - return values - - def insert(self, keys, values, name=None): - """Associates `keys` with `values`. - - Args: - keys: Keys to insert. Can be a tensor of any shape. Must match the - table's key type. - values: Values to be associated with keys. Must be a tensor of the same - shape as `keys` and match the table's value type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` or `values` doesn't match the table data - types. - """ - with ops.name_scope(name, "%s_lookup_table_insert" % self.name, - [self.resource_handle, keys, values]) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - values = ops.convert_to_tensor( - values, dtype=self._value_dtype, name="values") - with ops.colocate_with(self.resource_handle): - op = gen_lookup_ops.lookup_table_insert_v2( - self.resource_handle, keys, values, name=name) - return op - - def remove(self, keys, name=None): - """Removes `keys` and its associated values from the table. - - If a key is not present in the table, it is silently ignored. - - Args: - keys: Keys to remove. Can be a tensor of any shape. Must match the table's - key type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." 
% - (self._key_dtype, keys.dtype)) - - with ops.name_scope( - name, "%s_lookup_table_remove" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_remove_v2( - self.resource_handle, keys, name=name) - - return op - def export(self, name=None): """Returns tensors of all keys and values in the table. @@ -753,34 +351,15 @@ class MutableDenseHashTable(LookupInterface): A pair of tensors with the first tensor containing all keys and the second tensors containing all values in the table. """ - with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, + with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self.resource_handle, self._key_dtype, self._value_dtype, name=name) + exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( + self.resource_handle, self._key_dtype, self._value_dtype, name=name) + exported_values.set_shape(exported_keys.get_shape().concatenate( + self._value_shape)) return exported_keys, exported_values - def _gather_saveables_for_checkpoint(self): - """For object-based checkpointing.""" - return {"table": functools.partial( - MutableDenseHashTable._Saveable, table=self)} - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject implementation for MutableDenseHashTable.""" - - def __init__(self, table, name): - tensors = table.export() - specs = [ - BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), - BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") - ] - # pylint: disable=protected-access - super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name) - - def restore(self, restored_tensors, restored_shapes): - del restored_shapes # unused - # pylint: disable=protected-access - with ops.colocate_with(self.op.resource_handle): - 
return gen_lookup_ops.lookup_table_import_v2( - self.op.resource_handle, restored_tensors[0], restored_tensors[1]) + +MutableHashTable = lookup_ops.MutableHashTable +MutableDenseHashTable = lookup_ops.MutableDenseHashTable diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 9b2c2dd87cc8a92fbb6b45504939be3788b60839..9fe8dafcc8edd6b80625c61a4a0e783e65b44720 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -18,14 +18,10 @@ from __future__ import division from __future__ import print_function import os -import tempfile import numpy as np -import six from tensorflow.contrib import lookup from tensorflow.python.client import session -from tensorflow.python.data.experimental.ops import counter -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,9 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import saver from tensorflow.python.training import server_lib -from tensorflow.python.training.checkpointable import util as checkpointable class HashTableOpTest(test.TestCase): @@ -299,1240 +293,6 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([b"brain", b"salad", b"n/a"], result) -class MutableHashTableOpTest(test.TestCase): - - def testMutableHashTable(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, 
table.size().eval()) - - remove_string = constant_op.constant(["tarkus", "tank"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None], exported_values.get_shape().as_list()) - - # exported data is in the order of the internal map, i.e. undefined - sorted_keys = np.sort(exported_keys.eval()) - sorted_values = np.sort(exported_values.eval()) - self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) - self.assertAllEqual([0, 1, 2], sorted_values) - - def testSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - v0 = variables.Variable(10.0, name="v0") - v1 = variables.Variable(20.0, name="v1") - - default_val = -1 - keys = constant_op.constant(["b", "c", "d"], dtypes.string) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - - save = saver.Saver() - variables.global_variables_initializer().run() - - # Check that the parameter nodes have been initialized. 
- self.assertEqual(10.0, v0.eval()) - self.assertEqual(20.0, v1.eval()) - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - v0 = variables.Variable(-1.0, name="v0") - v1 = variables.Variable(-1.0, name="v1") - default_val = -1 - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - table.insert( - constant_op.constant(["a", "c"], dtypes.string), - constant_op.constant([12, 24], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - # Check that the parameter nodes have been restored. - self.assertEqual(10.0, v0.eval()) - self.assertEqual(20.0, v1.eval()) - - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["a", "b", "c", "d", "e"], - dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], output.eval()) - - @test_util.run_in_graph_and_eager_modes - def testObjectSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - v0 = variables.Variable(10.0, name="v0") - v1 = variables.Variable(20.0, name="v1") - - default_val = -1 - keys = constant_op.constant(["b", "c", "d"], dtypes.string) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - - checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1) - self.evaluate([v0.initializer, v1.initializer]) - - # Check that the parameter nodes have been initialized. 
- self.assertEqual(10.0, self.evaluate(v0)) - self.assertEqual(20.0, self.evaluate(v1)) - - self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) - - save_path = checkpoint.save(save_prefix) - del table, checkpoint, v0, v1 - - v0 = variables.Variable(-1.0, name="v0") - v1 = variables.Variable(-1.0, name="v1") - default_val = -1 - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - self.evaluate(table.insert( - constant_op.constant(["a", "c"], dtypes.string), - constant_op.constant([12, 24], dtypes.int64))) - self.assertAllEqual(2, self.evaluate(table.size())) - - checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1) - - # Restore the saved values in the parameter nodes. - checkpoint.restore(save_path).run_restore_ops() - # Check that the parameter nodes have been restored. - self.assertEqual(10.0, self.evaluate(v0)) - self.assertEqual(20.0, self.evaluate(v1)) - - self.assertAllEqual(3, self.evaluate(table.size())) - - input_string = constant_op.constant(["a", "b", "c", "d", "e"], - dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) - - def testSharing(self): - # Start a server to store the table state - server = server_lib.Server( - { - "local0": ["localhost:0"] - }, protocol="grpc", start=True) - # Create two sessions sharing the same state - session1 = session.Session(server.target) - session2 = session.Session(server.target) - - table = lookup.MutableHashTable( - dtypes.int64, dtypes.string, "-", name="t1") - - # Populate the table in the first session - with session1: - self.assertAllEqual(0, table.size().eval()) - - keys = constant_op.constant([11, 12], dtypes.int64) - values = constant_op.constant(["a", "b"]) - table.insert(keys, values).run() - self.assertAllEqual(2, table.size().eval()) - - output = 
table.lookup(constant_op.constant([11, 12, 13], dtypes.int64)) - self.assertAllEqual([b"a", b"b", b"-"], output.eval()) - - # Verify that we can access the shared data from the second session - with session2: - self.assertAllEqual(2, table.size().eval()) - - output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64)) - self.assertAllEqual([b"-", b"a", b"b"], output.eval()) - - def testMutableHashTableOfTensors(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["tarkus", "tank"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - self.assertAllEqual([3, 2], output.get_shape()) - - result = output.eval() - self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list(), - msg="Saw shape %s" % exported_keys.shape) - self.assertAllEqual([None, 2], exported_values.get_shape().as_list(), - msg="Saw shape %s" % exported_values.shape) - # exported data is in the order of the internal map, i.e. 
undefined - sorted_keys = np.sort(exported_keys.eval()) - sorted_values = np.sort(exported_values.eval()) - self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) - self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values) - - def testMutableHashTableExportInsert(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table1.size().eval()) - table1.insert(keys, values).run() - self.assertAllEqual(3, table1.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - expected_output = [[0, 1], [2, 3], [-1, -1]] - output1 = table1.lookup(input_string) - self.assertAllEqual(expected_output, output1.eval()) - - exported_keys, exported_values = table1.export() - self.assertAllEqual(3, exported_keys.eval().size) - self.assertAllEqual(6, exported_values.eval().size) - - # Populate a second table from the exported data - table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table2.size().eval()) - table2.insert(exported_keys, exported_values).run() - self.assertAllEqual(3, table2.size().eval()) - - # Verify lookup result is still the same - output2 = table2.lookup(input_string) - self.assertAllEqual(expected_output, output2.eval()) - - def testMutableHashTableOfTensorsInvalidShape(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - # Shape [6] instead of [3, 2] - values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [2,3] 
instead of [3, 2] - values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [2, 2] instead of [3, 2] - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [3, 1] instead of [3, 2] - values = constant_op.constant([[0], [2], [4]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Valid Insert - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - def testMutableHashTableInvalidDefaultValue(self): - with self.cached_session(): - default_val = constant_op.constant([[-1, -1]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - with self.assertRaisesOpError("Default value must be a vector"): - self.assertAllEqual(0, table.size().eval()) - - def testMutableHashTableDuplicateInsert(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([3, 1, -1], result) - - def testMutableHashTableFindHighRank(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() 
- self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant( - [["brain", "salad"], ["tank", "tarkus"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2], output.get_shape()) - - result = output.eval() - self.assertAllEqual([[0, 1], [-1, -1]], result) - - def testMutableHashTableInsertHighRank(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, 3, -1], result) - - def testMutableHashTableRemoveHighRank(self): - with self.test_session(): - default_val = -1 - keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["salad", "tarkus"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, -1, 3, -1], result) - - def testMutableHashTableOfTensorsFindHighRank(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - 
table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant( - [["brain", "salad"], ["tank", "tarkus"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2, 3], output.get_shape()) - - result = output.eval() - self.assertAllEqual( - [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) - - def testMutableHashTableOfTensorsRemoveHighRank(self): - with self.test_session(): - default_val = constant_op.constant([-1, -1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - remove_string = constant_op.constant([["brain", "tank"]]) - table.remove(remove_string).run() - self.assertAllEqual(2, table.size().eval()) - - input_string = constant_op.constant([["brain", "salad"], - ["surgery", "tank"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2, 3], output.get_shape()) - - result = output.eval() - self.assertAllEqual( - [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result) - - def testMultipleMutableHashTables(self): - with self.cached_session() as sess: - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - - table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table3 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table1.insert(keys, values).run() - table2.insert(keys, values).run() - table3.insert(keys, values).run() - - self.assertAllEqual(3, table1.size().eval()) - self.assertAllEqual(3, table2.size().eval()) - self.assertAllEqual(3, table3.size().eval()) - - input_string = 
constant_op.constant(["brain", "salad", "tank"]) - output1 = table1.lookup(input_string) - output2 = table2.lookup(input_string) - output3 = table3.lookup(input_string) - - out1, out2, out3 = sess.run([output1, output2, output3]) - self.assertAllEqual([0, 1, -1], out1) - self.assertAllEqual([0, 1, -1], out2) - self.assertAllEqual([0, 1, -1], out3) - - def testMutableHashTableWithTensorDefault(self): - with self.cached_session(): - default_val = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testSignatureMismatch(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - # insert with keys of the wrong type - with self.assertRaises(ValueError): - table.insert(constant_op.constant([4, 5, 6]), values).run() - - # insert with values of the wrong type - with self.assertRaises(ValueError): - table.insert(keys, constant_op.constant(["a", "b", "c"])).run() - - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string_ref = variables.Variable("brain") - input_int64_ref = variables.Variable(-1, dtype=dtypes.int64) - variables.global_variables_initializer().run() - - # Ref types do not produce an insert signature mismatch. 
- table.insert(input_string_ref, input_int64_ref).run() - self.assertAllEqual(3, table.size().eval()) - - # Ref types do not produce a lookup signature mismatch. - self.assertEqual(-1, table.lookup(input_string_ref).eval()) - - # lookup with keys of the wrong type - input_string = constant_op.constant([1, 2, 3], dtypes.int64) - with self.assertRaises(ValueError): - table.lookup(input_string).eval() - - # default value of the wrong type - with self.assertRaises(TypeError): - lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK") - - def testMutableHashTableStringFloat(self): - with self.cached_session(): - default_val = -1.5 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) - table = lookup.MutableHashTable(dtypes.string, dtypes.float32, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllClose([0, 1.1, default_val], result) - - def testMutableHashTableIntFloat(self): - with self.cached_session(): - default_val = -1.0 - keys = constant_op.constant([3, 7, 0], dtypes.int64) - values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) - table = lookup.MutableHashTable(dtypes.int64, dtypes.float32, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([7, 0, 11], dtypes.int64) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllClose([-1.2, 9.9, default_val], result) - - def testMutableHashTableInt64String(self): - with self.cached_session(): - default_val = "n/a" - keys = constant_op.constant([0, 1, 2], dtypes.int64) - values = constant_op.constant(["brain", "salad", "surgery"]) - table = 
lookup.MutableHashTable(dtypes.int64, dtypes.string, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([0, 1, 3], dtypes.int64) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual((b"brain", b"salad", b"n/a"), result) - - -class MutableDenseHashTableOpTest(test.TestCase): - - def testBasic(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([12, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, -1, -1], result) - - def testBasicBool(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([True, True, True, True], dtypes.bool) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.bool, - default_value=False, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([11, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - 
result = output.eval() - self.assertAllEqual([False, True, False], result) - - def testSameEmptyAndDeletedKey(self): - with self.cached_session(): - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=42) - self.assertAllEqual(0, table.size().eval()) - - def testLookupUnknownShape(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - placeholder_keys = array_ops.placeholder(dtypes.int64) - output = table.lookup(placeholder_keys) - self.assertAllEqual(None, output.get_shape()) - result = output.eval({placeholder_keys: [11, 12, 15]}) - self.assertAllEqual([0, 1, -1], result) - - def testMapStringToFloat(self): - with self.cached_session(): - - keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string) - values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32) - default_value = constant_op.constant(-1.5, dtypes.float32) - table = lookup.MutableDenseHashTable( - dtypes.string, - dtypes.float32, - default_value=default_value, - empty_key="", - deleted_key="$") - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["b", "e"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllClose([0, -1.5, 3.3, -1.5], result) - - def testMapInt64ToFloat(self): - for 
float_dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype) - default_value = constant_op.constant(-1.5, float_dtype) - table = lookup.MutableDenseHashTable( - dtypes.int64, - float_dtype, - default_value=default_value, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([12, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllClose([0, -1.5, 3.3, -1.5], result) - - def testVectorValues(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], - dtypes.int64) - default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=0, - deleted_key=-1, - initial_num_buckets=4) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - table.insert( - constant_op.constant([14], dtypes.int64), - constant_op.constant([[2, 3, 4, 5]], dtypes.int64)).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - remove_string = constant_op.constant([12, 16], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - input_string = 
constant_op.constant([11, 12, 14, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4, 4], - output.shape, - msg="Saw shape: %s" % output.shape) - - result = output.eval() - self.assertAllEqual( - [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]], - result) - - def testVectorKeys(self): - with self.cached_session(): - keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) - values = constant_op.constant([10, 11, 12], dtypes.int64) - empty_key = constant_op.constant([0, 3], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - table.insert( - constant_op.constant([[0, 0]], dtypes.int64), - constant_op.constant([13], dtypes.int64)).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]], - dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllEqual([10, -1, 12, -1], result) - - def testResize(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1, - initial_num_buckets=4) - self.assertAllEqual(0, 
table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - keys2 = constant_op.constant([12, 99], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) - values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) - - table.insert(keys3, values3).run() - self.assertAllEqual(6, table.size().eval()) - self.assertAllEqual(16, len(table.export()[0].eval())) - - keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], - dtypes.int64) - output = table.lookup(keys4) - self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], output.eval()) - - def testExport(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([1, 2, 3, 4], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=100, - deleted_key=200, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - keys2 = constant_op.constant([12, 15], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None], exported_values.get_shape().as_list()) - - np_keys = exported_keys.eval() - np_values = exported_values.eval() - - self.assertAllEqual(8, len(np_keys)) - self.assertAllEqual(8, len(np_values)) - - # pair up keys and values, drop extra added dimension - pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] - # sort by key - pairs = pairs[pairs[:, 0].argsort()] - self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0], - [100, 
0], [100, 0], [200, 2]], pairs) - - def testSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - default_value = -1 - empty_key = 0 - deleted_key = -1 - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([12, 15], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([11, 14], dtypes.int64), - constant_op.constant([12, 24], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. 
- save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, -1, 2, 3], output.eval()) - - @test_util.run_in_graph_and_eager_modes - def testObjectSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - default_value = -1 - empty_key = 0 - deleted_key = -1 - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - save_table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save_checkpoint = checkpointable.Checkpoint(table=save_table) - - self.assertAllEqual(0, self.evaluate(save_table.size())) - self.evaluate(save_table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(save_table.size())) - self.assertAllEqual(32, len(self.evaluate(save_table.export()[0]))) - - save_path = save_checkpoint.save(save_prefix) - del save_table, save_checkpoint - - load_table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - self.evaluate(load_table.insert( - constant_op.constant([11, 14], dtypes.int64), - constant_op.constant([12, 24], dtypes.int64))) - self.assertAllEqual(2, self.evaluate(load_table.size())) - self.assertAllEqual(64, len(self.evaluate(load_table.export()[0]))) - - restore_checkpoint = checkpointable.Checkpoint(table=load_table) - - # Restore the saved values in the parameter nodes. 
- restore_checkpoint.restore(save_path).run_restore_ops() - - self.assertAllEqual(3, self.evaluate(load_table.size())) - self.assertAllEqual(32, len(self.evaluate(load_table.export()[0]))) - - input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) - output = load_table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) - - def testVectorSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-2, -3], dtypes.int64) - default_value = constant_op.constant([-1, -2], dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], - dtypes.int64) - values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]], - dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-2, -3], dtypes.int64) - default_value = constant_op.constant([-1, -2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - 
dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([[11, 12], [13, 15]], dtypes.int64), - constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant( - [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]], - output.eval()) - - def testVectorScalarSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], - dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t2", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - 
self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t2", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([[11, 12], [13, 15]], dtypes.int64), - constant_op.constant([3, 4], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant( - [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([0, 1, -1, 3, -1], output.eval()) - - def testReprobe(self): - with self.cached_session(): - # Insert 6 keys into a table with 8 buckets. 
- # The values are chosen to make sure collisions occur when using GCC STL - keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) - values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(6, table.size().eval()) - - input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22], - dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([9], output.get_shape()) - - result = output.eval() - self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) - - def testCustomEmptyKey(self): - with self.cached_session(): - keys = constant_op.constant([11, 0, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=12, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 0, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testErrors(self): - with self.cached_session(): - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - # Inserting the empty key returns an error - keys1 = constant_op.constant([11, 0], dtypes.int64) - values1 = constant_op.constant([0, 1], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "empty_key"): - table.insert(keys1, values1).run() - - # Looking up the empty key returns an error - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - 
"empty_key"): - table.lookup(keys1).eval() - - # Inserting the deleted key returns an error - keys2 = constant_op.constant([11, -1], dtypes.int64) - values2 = constant_op.constant([0, 1], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table.insert(keys2, values2).run() - - # Looking up the empty key returns an error - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table.lookup(keys2).eval() - - # Arbitrary tensors of keys are not supported - keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) - values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Expected key shape"): - table.lookup(keys).eval() - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Expected key shape"): - table.insert(keys, values).run() - - table2 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=17, - deleted_key=-1, - initial_num_buckets=12) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Number of buckets must be"): - self.assertAllEqual(0, table2.size().eval()) - - with self.assertRaisesRegexp( - errors_impl.InvalidArgumentError, - "Empty and deleted keys must have same shape"): - table3 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=[1, 2]) - self.assertAllEqual(0, table3.size().eval()) - - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Empty and deleted keys cannot be equal"): - table4 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=42) - self.assertAllEqual(0, table4.size().eval()) - - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Empty and deleted keys cannot be equal"): - table5 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - 
empty_key=[1, 2, 3], - deleted_key=[1, 2, 3]) - self.assertAllEqual(0, table5.size().eval()) - - class IndexTableFromFile(test.TestCase): def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): @@ -2721,64 +1481,6 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup.StrongHashSpec([None, 2])) -class MutableHashTableBenchmark(test.Benchmark): - - def _create_table(self): - return lookup.MutableHashTable(dtypes.int64, dtypes.float32, 0.0) - - def benchmark_single_repeated_scalar_insert_scalar(self): - table = self._create_table() - value = variables.Variable(1.0) - insert = table.insert(0, value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) - assert sess.run(size) == 1 - - def benchmark_many_repeated_scalar_insert_scalar(self): - table = self._create_table() - c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() - value = variables.Variable(1.0) - insert = table.insert(c, value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) - assert sess.run(size) >= 10000 - - def benchmark_single_repeated_batch_32_insert_scalar(self): - table = self._create_table() - value = variables.Variable([1.0] * 32) - insert = table.insert(list(range(32)), value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000) - assert sess.run(size) == 32 - - def benchmark_many_repeated_batch_32_insert_scalar(self): - table = self._create_table() - c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() - value = variables.Variable([1.0] * 32) - insert = table.insert(32 * c + list(range(32)), value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, 
burn_iters=10, min_iters=1000) - assert sess.run(size) >= 1000*32 - - -class MutableDenseHashTableBenchmark(MutableHashTableBenchmark): - - def _create_table(self): - return lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.float32, - default_value=0.0, - empty_key=-1, - deleted_key=-2) - - if __name__ == "__main__": test.main() + diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD index 728f75f8ef1eb3b107dbd0ab4ffbecd63787bf3e..f4ebbdeee883ddeef0d47cb561901c16e2195bb2 100644 --- a/tensorflow/contrib/losses/BUILD +++ b/tensorflow/contrib/losses/BUILD @@ -82,10 +82,11 @@ py_library( py_test( name = "metric_loss_ops_test", - size = "large", + size = "medium", srcs = [ "python/metric_learning/metric_loss_ops_test.py", ], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":metric_learning_py", diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 709a042bbcefb89125f7e4cd14a0d7ecd2b53281..5ebdd0b8b50063c99e6b747c594eb99c306b4efb 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -511,7 +511,7 @@ def mean_squared_error(predictions, labels=None, weights=1.0, scope=None): predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) - losses = math_ops.square(math_ops.subtract(predictions, labels)) + losses = math_ops.squared_difference(predictions, labels) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py index de76acb51ffe985162a66c617b266f47c5216b19..f3b0e77740ff1d940fcd6d00b3482e90f6ebf952 100644 --- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ 
-105,7 +105,8 @@ def contrastive_loss(labels, embeddings_anchor, embeddings_positive, # Get per pair distances distances = math_ops.sqrt( math_ops.reduce_sum( - math_ops.square(embeddings_anchor - embeddings_positive), 1)) + math_ops.squared_difference(embeddings_anchor, embeddings_positive), + 1)) # Add contrastive loss for the siamese network. # label here is {0,1} for neg, pos. diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 2a5232b476712a96f84be0f4725beb78bc138297..af3c541dc214c30e9e59fdcca995ffc53b028df4 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -142,5 +142,6 @@ replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DAT # TODO(satok): Remove this once protobuf/autogen.sh is fixed. replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#http://download.tensorflow.org/deps/gmock-1.7.0.zip#' \ "${DOWNLOADS_DIR}/protobuf/autogen.sh" +cat "third_party/eigen3/gebp_neon.patch" | patch "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h" echo "download_dependencies.sh completed successfully." 
>&2 diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 9ea94c74330e3e49414a6a84cd5bc0db3778114a..0a0ba36232075460b561bc54a95fc24973017571 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -40,7 +40,6 @@ tensorflow/core/lib/wav/wav_io.cc tensorflow/core/platform/cpu_info.cc tensorflow/core/platform/default/logging.cc tensorflow/core/platform/default/mutex.cc -tensorflow/core/platform/default/protobuf.cc tensorflow/core/platform/default/tracing.cc tensorflow/core/platform/denormal.cc tensorflow/core/platform/env.cc @@ -53,6 +52,7 @@ tensorflow/core/platform/posix/error.cc tensorflow/core/platform/posix/load_library.cc tensorflow/core/platform/posix/port.cc tensorflow/core/platform/posix/posix_file_system.cc +tensorflow/core/platform/protobuf.cc tensorflow/core/platform/protobuf_util.cc tensorflow/core/platform/setround.cc tensorflow/core/platform/tensor_coding.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 87c73ec1ca610cac6d63468887bc350bada5910b..1c1460ce77c99d29785c7e8b8a8e9f770a45b59f 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.cc tensorflow/core/framework/versions.pb.cc tensorflow/core/grappler/costs/op_performance_data.pb.cc tensorflow/core/lib/core/error_codes.pb.cc +tensorflow/core/protobuf/trackable_object_graph.pb.cc tensorflow/core/protobuf/cluster.pb.cc tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/eager_service.pb.cc @@ -34,8 +35,11 @@ tensorflow/core/protobuf/meta_graph.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc +tensorflow/core/protobuf/saved_object_graph.pb.cc 
tensorflow/core/protobuf/saver.pb.cc +tensorflow/core/protobuf/struct.pb.cc tensorflow/core/protobuf/tensorflow_server.pb.cc +tensorflow/core/protobuf/verifier_config.pb.cc tensorflow/core/util/event.pb.cc tensorflow/core/util/memmapped_file_system.pb.cc tensorflow/core/util/saved_tensor_slice.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 4120ea52ec5255b1efce7a6ce6890fc79c1e4831..5def632e8a7b65272a1339bdacd92c1fa23012d2 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.h tensorflow/core/framework/versions.pb.h tensorflow/core/grappler/costs/op_performance_data.pb.h tensorflow/core/lib/core/error_codes.pb.h +tensorflow/core/protobuf/trackable_object_graph.pb.h tensorflow/core/protobuf/cluster.pb.h tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/debug.pb.h @@ -34,9 +35,12 @@ tensorflow/core/protobuf/meta_graph.pb.h tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/rewriter_config.pb.h +tensorflow/core/protobuf/saved_object_graph.pb.h tensorflow/core/protobuf/saver.pb.h +tensorflow/core/protobuf/struct.pb.h tensorflow/core/protobuf/tensor_bundle.pb.h tensorflow/core/protobuf/tensorflow_server.pb.h +tensorflow/core/protobuf/verifier_config.pb.h tensorflow/core/util/event.pb.h tensorflow/core/util/memmapped_file_system.pb.h tensorflow/core/util/saved_tensor_slice.pb.h diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 655c7eefcb978d40c8bc16a23685e03ed71bfb63..2cd7d6d519a55423a96526b541845392d9ec6bc2 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -119,6 +119,7 @@ tensorflow/core/kernels/fake_quant_ops.cc tensorflow/core/kernels/fifo_queue.cc 
tensorflow/core/kernels/fifo_queue_op.cc tensorflow/core/kernels/fill_functor.cc +tensorflow/core/kernels/fft_ops.cc tensorflow/core/kernels/function_ops.cc tensorflow/core/kernels/fused_batch_norm_op.cc tensorflow/core/kernels/gather_functor.cc diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index f94d70db9046cec43073ab1406762aea1f28c8e3..13e3b6422d1989b0d499d8d20901d919554c630e 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -29,5 +29,6 @@ tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc tensorflow/core/protobuf/tensor_bundle.pb_text.cc +tensorflow/core/protobuf/verifier_config.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/util/saved_tensor_slice.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 2712e906d719e72dacb60e213205ad68895f905f..deb6a5b94020a02b878bdd68a33b3737a97fcf2b 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -31,6 +31,7 @@ tensorflow/core/framework/versions.proto tensorflow/core/grappler/costs/op_performance_data.proto tensorflow/core/kernels/boosted_trees/boosted_trees.proto tensorflow/core/lib/core/error_codes.proto +tensorflow/core/protobuf/trackable_object_graph.proto tensorflow/core/protobuf/cluster.proto tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/debug.proto @@ -40,9 +41,12 @@ tensorflow/core/protobuf/meta_graph.proto tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/rewriter_config.proto +tensorflow/core/protobuf/saved_object_graph.proto tensorflow/core/protobuf/saver.proto +tensorflow/core/protobuf/struct.proto tensorflow/core/protobuf/tensor_bundle.proto 
tensorflow/core/protobuf/tensorflow_server.proto +tensorflow/core/protobuf/verifier_config.proto tensorflow/core/util/event.proto tensorflow/core/util/memmapped_file_system.proto tensorflow/core/util/saved_tensor_slice.proto diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD index 63843b993c16363a80b64622af665aaa64e05830..93701249cc8bf722c8c8558e91e0b700ca1c4a04 100644 --- a/tensorflow/contrib/memory_stats/BUILD +++ b/tensorflow/contrib/memory_stats/BUILD @@ -10,6 +10,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -45,6 +46,28 @@ tf_gen_op_wrapper_py( deps = [":memory_stats_ops_op_lib"], ) +tf_gen_op_wrapper_cc( + name = "memory_stats_ops", + out_ops_file = "memory_stats_ops", +) + +cc_library( + name = "memory_stats_cc", + srcs = ["memory_stats_ops.cc"], + hdrs = ["memory_stats_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":memory_stats_kernels", + ":memory_stats_ops_op_lib", + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + tf_custom_op_py_library( name = "memory_stats_py", srcs = [ diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index 974fb537499c5ea4591a0a128f53d2dea67b9e57..7ae1dbeaa2d04d7846e7fada117f3941319cc1c1 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -24,13 +24,15 @@ class 
MemoryStatsOp : public OpKernel { void Compute(OpKernelContext* context) override { Allocator* allocator = context->device()->GetAllocator(AllocatorAttributes()); - AllocatorStats allocator_stats; - allocator->GetStats(&allocator_stats); + absl::optional allocator_stats = allocator->GetStats(); + if (!allocator_stats) { + *allocator_stats = AllocatorStats(); + } Tensor* output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape({}), &output_tensor)); - output_tensor->scalar()() = ExtractAllocatorStats(allocator_stats); + output_tensor->scalar()() = ExtractAllocatorStats(*allocator_stats); } protected: @@ -71,7 +73,7 @@ class BytesLimitOp : public MemoryStatsOp { private: int64 ExtractAllocatorStats( const AllocatorStats& allocator_stats) const override { - return allocator_stats.bytes_limit; + return allocator_stats.bytes_limit ? *allocator_stats.bytes_limit : -1; } }; @@ -93,7 +95,7 @@ class MaxBytesInUseOp : public MemoryStatsOp { private: int64 ExtractAllocatorStats( const AllocatorStats& allocator_stats) const override { - return allocator_stats.max_bytes_in_use; + return allocator_stats.peak_bytes_in_use; } }; diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 7b432f8bd20989c6d95310bcaca88d44ce3e0d1f..ece246b7c28569a551f7733daf16ee1507f9c95d 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1356,9 +1356,8 @@ def _compute_placement_auc(labels, predictions, weights, alpha, weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) / (total_0 - 1. + _EPSILON)) var_1 = ( - math_ops.reduce_sum( - weights_1 * math_ops.square(placement_values_1 - auc_1)) / - (total_1 - 1. + _EPSILON)) + math_ops.reduce_sum(weights_1 * math_ops.squared_difference( + placement_values_1, auc_1)) / (total_1 - 1. 
+ _EPSILON)) auc_std_err = math_ops.sqrt( (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON))) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 45a60d79482787df4564ae3360f8252af93c7a26..710a262f33872ada8d090d796f80dc06c2a27f84 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -53,7 +53,6 @@ The pruning library allows for specification of the following hyper parameters: | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | | threshold_decay | float | 0.0 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | -| nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. 
| | block_height|integer | 1 | Number of rows in a block for block sparse matrices| | block_width |integer | 1 | Number of cols in a block for block sparse matrices| | block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)| diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index f6b4373edd0544555dd16a373802d2feb5d674b1..9966f7cf798d206fffbaeb4d16b6500a90d113e4 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -214,7 +214,7 @@ def get_pruning_hparams(): target_sparsity=0.5, sparsity_function_begin_step=0, sparsity_function_end_step=100, - sparsity_function_exponent=3, + sparsity_function_exponent=3.0, use_tpu=False) @@ -397,28 +397,26 @@ class Pruning(object): raise ValueError('Sparsity variable undefined') sparsity = self._get_sparsity(weights.op.name) - with ops.name_scope(weights.op.name + '_pruning_ops'): abs_weights = math_ops.abs(weights) - max_value = math_ops.reduce_max(abs_weights) - cdf_fn = pruning_utils.compute_cdf_from_histogram - if self._spec.use_tpu: - cdf_fn = pruning_utils.compute_cdf - - norm_cdf = cdf_fn(abs_weights, [0.0, max_value], nbins=self._spec.nbins) - current_threshold = math_ops.multiply( - math_ops.div( - math_ops.reduce_sum( - math_ops.cast( - math_ops.less(norm_cdf, sparsity), dtypes.float32)), - float(self._spec.nbins)), max_value) - + k = math_ops.cast( + math_ops.round( + math_ops.cast(array_ops.size(abs_weights), dtypes.float32) * + (1 - sparsity)), dtypes.int32) + # Sort the entire array + values, _ = nn_ops.top_k( + array_ops.reshape(abs_weights, [-1]), k=array_ops.size(abs_weights)) + # Grab the (k-1) th value + current_threshold = array_ops.gather(values, k - 1) smoothed_threshold = math_ops.add_n([ math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay), math_ops.multiply(threshold, self._spec.threshold_decay) 
]) + new_mask = math_ops.cast( - math_ops.greater(abs_weights, smoothed_threshold), dtypes.float32) + math_ops.greater_equal(abs_weights, smoothed_threshold), + dtypes.float32) + return smoothed_threshold, new_mask def _maybe_update_block_mask(self, weights, threshold): diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index 1b6da5ce2b4ebb3ea3b204c4ed12bed8db951447..835614d8822147dadb029107ae0e917cc955eef0 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -102,7 +102,7 @@ class PruningTest(test.TestCase): weights = variables.VariableV1( math_ops.linspace(1.0, 100.0, 100), name="weights") masked_weights = pruning.apply_mask(weights) - sparsity = variables.VariableV1(0.5, name="sparsity") + sparsity = variables.VariableV1(0.95, name="sparsity") p = pruning.Pruning(sparsity=sparsity) p._spec.threshold_decay = 0.0 mask_update_op = p.mask_update_op() @@ -111,7 +111,7 @@ class PruningTest(test.TestCase): self.assertAllEqual(np.count_nonzero(masked_weights_val), 100) session.run(mask_update_op) masked_weights_val = masked_weights.eval() - self.assertAllEqual(np.count_nonzero(masked_weights_val), 50) + self.assertAllEqual(np.count_nonzero(masked_weights_val), 5) def _blockMasking(self, hparams, weights, expected_mask): diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index 14fc51229ab53a77e8089040e8a8576babd0fafd..8f2ba036469bd02328a831a3d1de2ffbd10f5004 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -25,16 +25,12 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops from 
tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -_NBINS = 256 - def weight_mask_variable(var, scope): """Create a mask for the weights. @@ -165,128 +161,6 @@ def expand_tensor(tensor, block_dims): return expanded_tensor -def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): - """Return histogram of values. - - Given the tensor `values`, this operation returns a rank 1 histogram counting - the number of entries in `values` that fell into every bin. The bins are - equal width and determined by the arguments `value_range` and `nbins`. - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values`. - values <= value_range[0] will be mapped to hist[0], - values >= value_range[1] will be mapped to hist[-1]. - nbins: Scalar `int32 Tensor`. Number of histogram bins. - dtype: dtype for returned histogram. - name: A name for this operation (defaults to 'histogram'). - - Returns: - A 1-D `Tensor` holding histogram of values. - - """ - with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: - values = ops.convert_to_tensor(values, name='values') - values = array_ops.reshape(values, [-1]) - nbins_float = np.float32(nbins) - - # Map tensor values that fall within value_range to [0, 1]. - scaled_values = math_ops.truediv( - values - value_range[0], - value_range[1] - value_range[0], - name='scaled_values') - - # map tensor values within the open interval value_range to {0,.., nbins-1}, - # values outside the open interval will be zero or less, or nbins or more. - indices = math_ops.floor(nbins_float * scaled_values, name='indices') - - # Clip edge cases (e.g. value = value_range[1]) or "outliers." 
- indices = math_ops.cast( - clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32) - - return math_ops.unsorted_segment_sum( - array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope) - - -def compute_cdf_from_histogram(values, value_range, **kwargs): - """Returns the normalized cumulative distribution of the given values tensor. - - Computes the histogram and uses tf.cumsum to arrive at cdf - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values`. - **kwargs: keyword arguments: nbins, name - - Returns: - A 1-D `Tensor` holding normalized cdf of values. - - """ - nbins = kwargs.get('nbins', _NBINS) - name = kwargs.get('name', None) - with ops.name_scope(name, 'cdf', [values, value_range, nbins]): - histogram = _histogram( - values, value_range, dtype=dtypes.float32, nbins=nbins) - cdf = math_ops.cumsum(histogram) - return math_ops.div(cdf, math_ops.reduce_max(cdf)) - - -def compute_cdf(values, value_range, **kwargs): - """Returns the normalized cumulative distribution of the given values tensor. - - Uses tf.while_loop to directly compute the cdf of the values. - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values` - **kwargs: keyword arguments: nbins, name - - Returns: - A 1-D `Tensor` holding normalized cdf of values. - - """ - nbins = kwargs.get('nbins', _NBINS) - name = kwargs.get('name', None) - with ops.name_scope(name, 'cdf', [values, value_range, nbins]): - values = ops.convert_to_tensor(values, name='values') - nbins_float = np.float32(nbins) - - # Map tensor values that fall within value_range to [0, 1]. - scaled_values = math_ops.truediv( - values - value_range[0], - value_range[1] - value_range[0], - name='scaled_values') - - # map tensor values within the open interval value_range to {0,.., nbins-1}, - # values outside the open interval will be zero or less, or nbins or more. 
- indices = math_ops.floor(nbins_float * scaled_values, name='indices') - - # Clip edge cases (e.g. value = value_range[1]) or "outliers." - indices = math_ops.cast( - clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32) - - cdf = array_ops.zeros(nbins) - i = constant_op.constant(0) - - def loop_cond(loop_count, _): - return math_ops.less(loop_count, nbins) - - def loop_body(loop_count, cdf): - temp = math_ops.reduce_sum( - math_ops.cast( - math_ops.less_equal(indices, loop_count), dtypes.float32)) - cdf = math_ops.add( - cdf, - array_ops.one_hot( - loop_count, depth=nbins, on_value=temp, off_value=0.0)) - return [loop_count + 1, cdf] - - _, cdf = control_flow_ops.while_loop( - loop_cond, loop_body, [i, cdf], maximum_iterations=nbins) - - return math_ops.div(cdf, math_ops.reduce_max(cdf)) - - def factorized_pool(input_tensor, window_shape, pooling_type, diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index d6f2bfcb6c2e2beda912eb538d8a4a0a17b486b3..b85bc413155d53cd6d53e98dae0ad626531f61eb 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -19,13 +19,9 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized -import numpy as np from tensorflow.contrib.model_pruning.python import pruning_utils -from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope @@ -33,57 +29,6 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class PruningUtilsTest(test.TestCase): - - def _compare_cdf(self, values): - abs_values = 
math_ops.abs(values) - max_value = math_ops.reduce_max(abs_values) - with self.cached_session(): - variables.global_variables_initializer().run() - cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( - abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) - cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) - self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval()) - - def testHistogram(self): - width = 10 - height = 10 - nbins = 100 - expected_histogram = np.full(nbins, 1.0) - init = init_ops.constant_initializer(np.linspace(0.0, 1.0, width * height)) - weights = variable_scope.get_variable( - "weights", [width, height], initializer=init) - histogram = pruning_utils._histogram( - weights, [0, 1.0], nbins, dtype=np.float32) - with self.cached_session(): - variables.global_variables_initializer().run() - computed_histogram = histogram.eval() - self.assertAllEqual(expected_histogram, computed_histogram) - - def testCDF(self): - nbins = 5 - weights = constant_op.constant([-1, 0, 1, 1.5, 2, 3, 4, 5, 10, 100]) - abs_weights = math_ops.abs(weights) - norm_cdf = pruning_utils.compute_cdf_from_histogram( - abs_weights, [0.0, 5.0], nbins=nbins) - expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32) - with self.cached_session() as sess: - variables.global_variables_initializer().run() - norm_cdf_val = sess.run(norm_cdf) - self.assertAllEqual(len(norm_cdf_val), nbins) - self.assertAllEqual(expected_cdf, norm_cdf_val) - - def testCDFEquivalence2D(self): - width = 100 - height = 100 - weights = variable_scope.get_variable("weights", shape=[width, height]) - self._compare_cdf(weights) - - def testCDFEquivalence4D(self): - weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128]) - self._compare_cdf(weights) - - @parameterized.named_parameters( ("Input_32x32_block_1x1", [32, 32], [1, 1]), # block size 6x6 diff --git a/tensorflow/contrib/mpi/mpi_server_lib.cc b/tensorflow/contrib/mpi/mpi_server_lib.cc index 
a31fa9ce0b3110d875689d74a41ca9f9cc85f532..e44e10af0814ba8d6d964dfc34a0470ce45c0b40 100644 --- a/tensorflow/contrib/mpi/mpi_server_lib.cc +++ b/tensorflow/contrib/mpi/mpi_server_lib.cc @@ -54,7 +54,10 @@ MPIServer::~MPIServer() { Status MPIServer::Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func) { - Status s = GrpcServer::Init(service_func, rendezvous_mgr_func); + GrpcServerOptions opts; + opts.service_func = service_func; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + Status s = GrpcServer::Init(opts); return s; } diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index ecac06354d2ce796f2a6021cdf2370d7c30ccab7..a7be92a35e0d62a61f7923ac61bb2c1267d039c6 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -52,7 +52,6 @@ tf_custom_op_library( deps = [ ":mpi_defines", ":mpi_message_proto_cc", - "//tensorflow/stream_executor:stream_executor_headers_lib", "//third_party/mpi", ], ) diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h index cae57ce60eb09509af69f8ccab9eacedea361548..9b5d52e1b648e62af93d5420885e4f22796e3ea1 100644 --- a/tensorflow/contrib/mpi_collectives/ring.h +++ b/tensorflow/contrib/mpi_collectives/ring.h @@ -129,7 +129,7 @@ cudaStream_t CudaStreamForMPI(); * has the fully accumulated Segment 1; and so on. The scatter-reduce is * complete. * - * Next, the allgather distributes these fully accumululated chunks across all + * Next, the allgather distributes these fully accumulated chunks across all * nodes. Communication proceeds in the same ring, once again in N-1 steps. At * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). 
* For example, at the first iteration, the following transfers will occur: diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index f4ac70eb1a720c2acc3ef942f269228156749cba..f30643cf3059754daaeee4093938ac47b26f76ea 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -14,6 +14,7 @@ py_library( name = "opt_py", srcs = [ "__init__.py", + "python/training/adam_gs_optimizer.py", "python/training/adamax.py", "python/training/addsign.py", "python/training/agn_optimizer.py", @@ -22,6 +23,7 @@ py_library( "python/training/external_optimizer.py", "python/training/ggt.py", "python/training/lars_optimizer.py", + "python/training/lazy_adam_gs_optimizer.py", "python/training/lazy_adam_optimizer.py", "python/training/matrix_functions.py", "python/training/model_average_optimizer.py", @@ -60,6 +62,21 @@ py_library( ], ) +py_test( + name = "adam_gs_optimizer_test", + srcs = ["python/training/adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "adamax_test", srcs = ["python/training/adamax_test.py"], @@ -148,6 +165,25 @@ py_test( ], ) +py_test( + name = "lazy_adam_gs_optimizer_test", + srcs = ["python/training/lazy_adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_test( name = "lazy_adam_optimizer_test", srcs = 
["python/training/lazy_adam_optimizer_test.py"], @@ -283,6 +319,9 @@ tf_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//third_party/py/numpy", ], + tags = [ + "oss_serial", + ], ) tf_py_test( @@ -374,8 +413,9 @@ py_test( py_test( name = "shampoo_test", - size = "large", + size = "medium", srcs = ["python/training/shampoo_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":opt_py", diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index c7ea68efa9a13a471bba3f41d0600855793b20a2..e8fc52342ceabb47da97ca0f3c8a01e419a221a1 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.opt.python.training.adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.agn_optimizer import * @@ -28,6 +29,7 @@ from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.lars_optimizer import * from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.lazy_adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * @@ -44,12 +46,14 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AdaMaxOptimizer', + 'AdamGSOptimizer', 'PowerSignOptimizer', 'AddSignOptimizer', 'DelayCompensatedGradientDescentOptimizer', 'DropStaleGradientOptimizer', 'ExternalOptimizerInterface', 'LARSOptimizer', + 'LazyAdamGSOptimizer', 
'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0b149ed17533adff3bd7cd8fd8ff94d171f72911 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py @@ -0,0 +1,223 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adam rewrite to use global step for computing beta1 & beta2 accumulation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("train.AdamOptimizer") +class AdamGSOptimizer(optimizer.Optimizer): + """Optimizer that implements the Adam algorithm. 
+ + See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + """ + + def __init__(self, + global_step=0, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + use_locking=False, + name="Adam"): + r"""Construct a new Adam optimizer. + + Branched from tf.train.AdamOptimizer. The only difference is to pass + global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. + + Initialization: + + $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$ + $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$ + $$t := 0 \text{(Initialize timestep)}$$ + + The update rule for `variable` with gradient `g` uses an optimization + described at the end of section2 of the paper: + + $$t := t + 1$$ + $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + + $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + + The default value of 1e-8 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. + + The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. 
This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + + Args: + global_step: tensorflow variable indicating the step. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. The exponential decay + rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". @compatibility(eager) When eager execution is + enabled, `learning_rate`, `beta1`, `beta2`, and `epsilon` can each be a + callable that takes no arguments and returns the actual value to use. + This can be useful for changing these values across different + invocations of optimizer functions. @end_compatibility + """ + super(AdamGSOptimizer, self).__init__(use_locking, name) + self._lr = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + self._global_step = global_step + self._global_step_on_worker = None + + # Tensor versions of the constructor arguments, created in _prepare(). + self._lr_t = None + self._beta1_t = None + self._beta2_t = None + self._epsilon_t = None + + def _get_beta_accumulators(self): + return (math_ops.pow(self._beta1_t, self._global_step_on_worker), + math_ops.pow(self._beta2_t, self._global_step_on_worker)) + + def _create_slots(self, var_list): + # Create slots for the first and second moments. 
+ for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + def _prepare(self): + lr = self._call_if_callable(self._lr) + beta1 = self._call_if_callable(self._beta1) + beta2 = self._call_if_callable(self._beta2) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") + + # Performance optimization so that worker creates a copy of the global step + # to avoid overloading the parameter server holding the global step. + self._global_step_on_worker = math_ops.cast( + array_ops.identity(self._global_step) + 1, dtypes.float32) + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.apply_adam( + var, + m, + v, + math_ops.cast(beta1_power, var.dtype.base_dtype), + math_ops.cast(beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, + use_locking=self._use_locking).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.resource_apply_adam( + var.handle, + m.handle, + v.handle, + math_ops.cast(beta1_power, grad.dtype.base_dtype), + math_ops.cast(beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, + 
use_locking=self._use_locking) + + def _apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + var_update = state_ops.assign_sub( + var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) + return control_flow_ops.group(*[var_update, m_t, v_t]) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, + var, + grad.indices, + lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda + x, + i, + v, + use_locking=self._use_locking)) + + def _resource_scatter_add(self, x, i, v): + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared(grad, var, indices, + self._resource_scatter_add) diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py 
b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c68c965aef3729bebe7d0e0dd707c344321d9e3f --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py @@ -0,0 +1,382 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for AdamGS.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + 
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamGSOptimizerTest(test.TestCase): + + def doTestSparse(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + 
self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + + def testSparseDevicePlacement(self): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = adam_gs_optimizer.AdamGSOptimizer(3.0) + minimize_op = optimizer.minimize(gathered_sum) + variables.global_variables_initializer().run() + minimize_op.run() + + def testSparseRepeatedIndices(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update = adam_gs_optimizer.AdamGSOptimizer( + 
global_step=repeated_index_global_step).apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update = adam_gs_optimizer.AdamGSOptimizer( + global_step=aggregated_global_step).apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = 
adam_gs_optimizer.AdamGSOptimizer(global_step=global_step, + learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertTrue(beta2_power is not None) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + 
@test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSharing(self): + for dtype in 
[dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testTwoSessions(self): + optimizer = adam_gs_optimizer.AdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = adam_gs_optimizer.AdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two unique slot variables for v1 and v2 respectively. 
+ self.assertEqual(4, len(set(opt.variables()))) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8827007e4d7f6722398a8e36bd626377842d92ef --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py @@ -0,0 +1,114 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LazyAdam rewrite to use global step for computing beta1 & beta2 accumulation. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops + + +class LazyAdamGSOptimizer(adam_gs_optimizer.AdamGSOptimizer): + """Variant of the Adam optimizer that handles sparse updates more efficiently. + + Branched from tf.contrib.opt.LazyAdamOptimizer. 
The only difference is to + pass global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. + + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. + This class provides lazier handling of gradient updates for sparse variables. + It only updates moving-average accumulators for sparse variable indices that + appear in the current batch, rather than updating the accumulators for all + indices. Compared with the original Adam optimizer, it can provide large + improvements in model training throughput for some applications. However, it + provides slightly different semantics than the original Adam algorithm, and + may lead to different empirical results. + """ + + def _apply_sparse(self, grad, var): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + m_t_slice = 
array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) + denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) + return control_flow_ops.group(var_update, m_t, v_t) + + def _resource_apply_sparse(self, grad, var, indices): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc9a02a546c8399172d0c5b58941b4d80179955 --- 
/dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py @@ -0,0 +1,402 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for LazyAdamGSOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import lazy_adam_gs_optimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class LazyAdamGSOptimizerTest(test.TestCase, parameterized.TestCase): + 
+ @parameterized.parameters([False, True]) + def testSparse(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = 
adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([False, True]) + def testSparseDevicePlacement(self, use_resource): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var = variables.Variable([[1.0], [2.0]]) + + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=3.0) + minimize_op = optimizer.minimize(gathered_sum, global_step=global_step) + variables.global_variables_initializer().run() + minimize_op.run() + + @parameterized.parameters([False, True]) + def testSparseRepeatedIndices(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + if use_resource: + repeated_index_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + 
repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=repeated_index_global_step) + repeated_update = repeated_update_opt.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=aggregated_global_step) + aggregated_update = aggregated_update_opt.apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. 
+ self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # The beta powers are derived from the global step rather than stored as + # non-slot variables, so only the m and v slots of v1 and v2 remain. 
+ self.assertLen(set(opt.variables()), 4) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index 248ffb1f7eb5dc27112ddf9b8670344904065ed0..1b7800f324b908e3c88fe90d31a2a08cbbd5ccf2 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -36,7 +36,7 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="Adam"): - """Construct a new Adam optimizer. + r"""Construct a new Adam optimizer. Initialization: diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 72019b31540a943582ebb4699013d9dcfc10769f..b2ea3daf82ed8daa6e0b9acd8e3cf258b8181615 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -44,14 +44,15 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import graph_view +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util -class NonLayerCheckpointable(tracking.Checkpointable): +class NonLayerTrackable(tracking.AutoTrackable): def __init__(self): - super(NonLayerCheckpointable, self).__init__() + super(NonLayerTrackable, self).__init__() self.a_variable = util.add_variable( self, name="a_variable", shape=[]) @@ -64,8 +65,8 @@ class MyModel(training.Model): super(MyModel, self).__init__() self._named_dense = core.Dense(1, use_bias=True) self._second = core.Dense(1, 
use_bias=False) - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() + # We can still track Trackables which aren't Layers. + self._non_layer = NonLayerTrackable() def call(self, values): ret = self._second(self._named_dense(values)) @@ -100,7 +101,7 @@ class CheckpointingTests(test.TestCase): other_model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) if context.executing_eagerly(): optimizer.minimize( @@ -116,11 +117,10 @@ class CheckpointingTests(test.TestCase): other_model(input_value), global_step=optimizer_step) self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) - named_variables, serialized_graph, _ = ( - util._serialize_object_graph( - root_checkpointable, saveables_cache=None)) + named_variables, serialized_graph, _ = graph_view.ObjectGraphView( + root_trackable).serialize_object_graph() expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", @@ -208,7 +208,7 @@ class CheckpointingTests(test.TestCase): def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model) input_value = constant_op.constant([[3.]]) if context.executing_eagerly(): @@ -217,24 +217,24 @@ class CheckpointingTests(test.TestCase): else: train_op = optimizer.minimize(model(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. 
- root_checkpointable.save_counter # pylint: disable=pointless-statement + root_trackable.save_counter # pylint: disable=pointless-statement self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.])) m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m") self.evaluate(state_ops.assign(m_bias_slot, [1.5])) - save_path = root_checkpointable.save(file_prefix=prefix) + save_path = root_trackable.save(file_prefix=prefix) self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.])) - self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) + self.evaluate(state_ops.assign(root_trackable.save_counter, 3)) optimizer_variables = self.evaluate(optimizer.variables()) self.evaluate(state_ops.assign(m_bias_slot, [-2.])) # Immediate restoration - status = root_checkpointable.restore(save_path=save_path).assert_consumed() + status = root_trackable.restore(save_path=save_path).assert_consumed() status.run_restore_ops() self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1])) - self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) + self.assertAllEqual(1, self.evaluate(root_trackable.save_counter)) self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) if not context.executing_eagerly(): return # Restore-on-create is only supported when executing eagerly @@ -440,7 +440,7 @@ class CheckpointingTests(test.TestCase): def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = tracking.Checkpointable() + root = util.Checkpoint() root.var = util.add_variable( root, name="var", initializer=0.) 
optimizer = adam.AdamOptimizer(0.1) @@ -455,21 +455,17 @@ class CheckpointingTests(test.TestCase): util.Checkpoint(root=root, optimizer=optimizer))) self.evaluate(train_op) self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = util.CheckpointableSaver(root).save( - os.path.join(checkpoint_directory, "no_slots")) + no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots")) root.optimizer = optimizer self.evaluate(state_ops.assign(root.var, 13.)) self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.)) - slots_path = util.CheckpointableSaver(root).save( - os.path.join(checkpoint_directory, "with_slots")) - new_root = tracking.Checkpointable() + slots_path = root.save(os.path.join(checkpoint_directory, "with_slots")) + new_root = util.Checkpoint() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). - slot_status = util.CheckpointableSaver( - new_root).restore(slots_path) - no_slot_status = util.CheckpointableSaver( - new_root).restore(no_slots_path) + slot_status = new_root.restore(slots_path) + no_slot_status = new_root.restore(no_slots_path) with self.assertRaises(AssertionError): no_slot_status.assert_consumed() new_root.var = util.add_variable( @@ -508,15 +504,14 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.Checkpointable() + obj = util.Checkpoint() obj.var = variable_scope.get_variable(name="v", initializer=0.) 
obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) self.evaluate(util.gather_initializers(obj)) - saver = util.CheckpointableSaver(obj) - saver.save(checkpoint_prefix) + obj.save(checkpoint_prefix) before_ops = graph.get_operations() - saver.save(checkpoint_prefix) + obj.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) def testManyRestoresGraph(self): @@ -526,16 +521,15 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.Checkpointable() + obj = util.Checkpoint() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) self.evaluate(util.gather_initializers(obj)) - saver = util.CheckpointableSaver(obj) - save_path = saver.save(checkpoint_prefix) - saver.restore(save_path) + save_path = obj.save(checkpoint_prefix) + obj.restore(save_path) before_ops = graph.get_operations() - saver.restore(save_path) + obj.restore(save_path) self.assertEqual(before_ops, graph.get_operations()) def testMultipleGraphsNonSlotVariables(self): @@ -548,11 +542,11 @@ class CheckpointingTests(test.TestCase): first_session = session_lib.Session(graph=first_graph) with first_graph.as_default(), first_session.as_default(): first_variable = resource_variable_ops.ResourceVariable([1.]) - first_root_checkpointable = util.Checkpoint( + first_root_trackable = util.Checkpoint( optimizer=optimizer, variable=first_variable) train_op = optimizer.minimize(first_variable.read_value) self.evaluate(util.gather_initializers( - first_root_checkpointable)) + first_root_trackable)) self.evaluate(train_op) self.evaluate(first_variable.assign([1.])) self.evaluate(optimizer.get_slot( @@ -564,23 +558,23 @@ class CheckpointingTests(test.TestCase): second_graph = ops.Graph() with second_graph.as_default(), 
session_lib.Session(graph=second_graph): second_variable = resource_variable_ops.ResourceVariable([1.]) - second_root_checkpointable = util.Checkpoint( + second_root_trackable = util.Checkpoint( optimizer=optimizer, variable=second_variable) train_op = optimizer.minimize(second_variable.read_value) - second_root_checkpointable.restore(None).initialize_or_restore() + second_root_trackable.restore(None).initialize_or_restore() self.evaluate(train_op) self.evaluate(second_variable.assign([4.])) self.evaluate(optimizer.get_slot( var=second_variable, name="m").assign([5.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(6.)) - save_path = second_root_checkpointable.save(checkpoint_prefix) + save_path = second_root_trackable.save(checkpoint_prefix) self.evaluate(second_variable.assign([7.])) self.evaluate(optimizer.get_slot( var=second_variable, name="m").assign([8.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.assertAllEqual(6., self.evaluate(beta_1_power)) - status = second_root_checkpointable.restore(save_path) + status = second_root_trackable.restore(save_path) status.assert_consumed().run_restore_ops() self.assertAllEqual([4.], self.evaluate(second_variable)) self.assertAllEqual([5.], self.evaluate(optimizer.get_slot( @@ -600,7 +594,7 @@ class CheckpointingTests(test.TestCase): class TemplateTests(test.TestCase): @test_util.run_in_graph_and_eager_modes - def test_checkpointable_save_restore(self): + def test_trackable_save_restore(self): def _templated(): v = variable_scope.get_variable( @@ -647,13 +641,13 @@ class CheckpointCompatibilityTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) train_op = optimizer.minimize( functools.partial(model, input_value), global_step=optimizer_step) 
self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) # A regular variable, a slot variable, and a non-slot Optimizer variable # with known values to check when loading. @@ -662,24 +656,24 @@ class CheckpointCompatibilityTests(test.TestCase): var=model._named_dense.bias, name="m").assign([2.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(3.)) - return root_checkpointable + return root_trackable - def _set_sentinels(self, root_checkpointable): - self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.])) + def _set_sentinels(self, root_trackable): + self.evaluate(root_trackable.model._named_dense.bias.assign([101.])) self.evaluate( - root_checkpointable.optimizer.get_slot( - var=root_checkpointable.model._named_dense.bias, name="m") + root_trackable.optimizer.get_slot( + var=root_trackable.model._named_dense.bias, name="m") .assign([102.])) - beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(103.)) - def _check_sentinels(self, root_checkpointable): + def _check_sentinels(self, root_trackable): self.assertAllEqual( - [1.], self.evaluate(root_checkpointable.model._named_dense.bias)) + [1.], self.evaluate(root_trackable.model._named_dense.bias)) self.assertAllEqual([2.], self.evaluate( - root_checkpointable.optimizer.get_slot( - var=root_checkpointable.model._named_dense.bias, name="m"))) - beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + root_trackable.optimizer.get_slot( + var=root_trackable.model._named_dense.bias, name="m"))) + beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators() self.assertAllEqual(3., self.evaluate(beta_1_power)) def _write_name_based_checkpoint(self): @@ -704,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase): self._set_sentinels(root) with 
self.assertRaises(AssertionError): self._check_sentinels(root) - object_saver = util.CheckpointableSaver(root) + object_saver = util.TrackableSaver(graph_view.ObjectGraphView(root)) self._set_sentinels(root) status = object_saver.restore(save_path) if context.executing_eagerly(): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 7fb23abc38d9dc101204ed83808aebe5a8ef1e78..a7f978634ed45012144b2cc49ed069f6fca44f66 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -24,7 +24,6 @@ import abc import six -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop @@ -39,7 +38,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest @@ -224,7 +223,7 @@ class _OptimizerV2State(object): } self._slots = {} self._non_slot_dict = {} - # Extra state to help Optimizers implement Checkpointable. Holds information + # Extra state to help Optimizers implement Trackable. Holds information # about variables which will be restored as soon as they're created. self._deferred_dependencies = {} # Non-slot variables self._deferred_slot_restorations = {} # Slot variables @@ -367,8 +366,8 @@ class _OptimizerV2State(object): slot variable needs to be restored). Args: - slot_variable_position: A `checkpointable._CheckpointPosition` object - indicating the slot variable `Checkpointable` object to be restored. 
+ slot_variable_position: A `trackable._CheckpointPosition` object + indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. optional_op_name: Name to use when scoping the Variable that needs to be @@ -386,7 +385,7 @@ class _OptimizerV2State(object): # (aside from double initialization), and makes variable creator scopes # behave the same way they do when graph building. and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access - initializer = checkpointable.CheckpointInitialValue( + initializer = trackable.CheckpointInitialValue( checkpoint_position=slot_variable_position) slot_variable = self.create_slot( var=variable, @@ -661,7 +660,7 @@ class OptimizerV2(optimizer_v1.Optimizer): name=None, grad_loss=None, stop_gradients=None, - scale_loss_by_num_replicas=None): + scale_loss_by_num_replicas=False): """Add operations to minimize `loss` by updating `var_list`. This method simply combines calls `compute_gradients()` and @@ -685,8 +684,7 @@ class OptimizerV2(optimizer_v1.Optimizer): stop_gradients: Optional. A Tensor or list of tensors not to differentiate through. scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down - by the number of replicas. By default, auto-detects whether this is - needed. + by the number of replicas. DEPRECATED and generally no longer needed. Returns: An Operation that updates the variables in `var_list`. If `global_step` @@ -732,7 +730,7 @@ class OptimizerV2(optimizer_v1.Optimizer): aggregation_method=None, grad_loss=None, stop_gradients=None, - scale_loss_by_num_replicas=None): + scale_loss_by_num_replicas=False): """Compute gradients of `loss` for the variables in `var_list`. This is the first part of `minimize()`. It returns a list @@ -756,8 +754,7 @@ class OptimizerV2(optimizer_v1.Optimizer): stop_gradients: Optional. 
A Tensor or list of tensors not to differentiate through. scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down - by the number of replicas. By default, auto-detects whether this is - needed. + by the number of replicas. DEPRECATED and generally no longer needed. Returns: A list of (gradient, variable) pairs. Variable is always present, but @@ -781,9 +778,7 @@ class OptimizerV2(optimizer_v1.Optimizer): tape.watch(var_list) loss_value = loss() - # Scale loss for number of replicas (callable-loss case). In this case, - # we have to be careful to call distribute_lib.get_loss_reduction() - # *after* loss() is evaluated, so we know what loss reduction it uses. + # Scale loss for number of replicas (callable-loss case). loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas) if var_list is None: @@ -839,12 +834,8 @@ class OptimizerV2(optimizer_v1.Optimizer): @staticmethod def _scale_loss(loss_value, scale_loss_by_num_replicas): """Scale loss for the number of replicas.""" - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) if scale_loss_by_num_replicas: - num_replicas = \ - distribute_ctx.get_distribution_strategy().num_replicas_in_sync + num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync if num_replicas > 1: loss_value *= 1. / num_replicas return loss_value @@ -1268,10 +1259,10 @@ class OptimizerV2(optimizer_v1.Optimizer): return self._per_graph_state.get(var._graph_key, None) # -------------- - # Overridden methods from Checkpointable. + # Overridden methods from Trackable. # -------------- - def _track_checkpointable(self, *args, **kwargs): + def _track_trackable(self, *args, **kwargs): """Optimizers may not track dependencies. Raises an error.""" raise NotImplementedError( "Optimizers may not have dependencies. 
File a feature request if this " @@ -1279,7 +1270,7 @@ class OptimizerV2(optimizer_v1.Optimizer): @property def _checkpoint_dependencies(self): - """From Checkpointable. Gather graph-specific non-slot variables to save.""" + """From Trackable. Gather graph-specific non-slot variables to save.""" current_graph_non_slot_variables = [] state = self._get_per_graph_state() if state is not None: @@ -1288,14 +1279,14 @@ class OptimizerV2(optimizer_v1.Optimizer): # Avoid comparing variables key=lambda item: item[0]): current_graph_non_slot_variables.append( - checkpointable.CheckpointableReference( + trackable.TrackableReference( name=name, ref=variable_object)) # Note: ignores super(); Optimizers may not have any dependencies outside of # state objects. return current_graph_non_slot_variables def _lookup_dependency(self, name): - """From Checkpointable. Find a non-slot variable in the current graph.""" + """From Trackable. Find a non-slot variable in the current graph.""" state = self._get_per_graph_state() if state is None: return None @@ -1304,10 +1295,10 @@ class OptimizerV2(optimizer_v1.Optimizer): @property def _deferred_dependencies(self): - """Lets Checkpointable know where non-slot variables are created. + """Lets Trackable know where non-slot variables are created. If necessary, creates a new state object for the current default graph. - Checkpointable will then add entries to that state's deferred dependency + Trackable will then add entries to that state's deferred dependency dictionary. The state object will check that dictionary when creating non-slot variables, restoring their value if an entry is found. @@ -1320,14 +1311,14 @@ class OptimizerV2(optimizer_v1.Optimizer): def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, variable): - """Checkpointable: Restore a slot variable's value, possibly creating it. + """Trackable: Restore a slot variable's value, possibly creating it. 
Called when a variable which has an associated slot variable is created or restored. Args: - slot_variable_position: A `checkpointable._CheckpointPosition` object - indicating the slot variable `Checkpointable` object to be restored. + slot_variable_position: A `trackable._CheckpointPosition` object + indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. """ diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index dd7f2f44055a2e48e8a48d01c1da3a8e7513255d..2fc0b5ea4de2332ff3bf32f9a12a15eee566d5c4 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import gradients_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables @@ -71,7 +71,7 @@ class OptimizerTest(test.TestCase): opt_op = sgd_op.minimize( cost, global_step, [var0, var1], - aggregation_method=gradients_impl.AggregationMethod. + aggregation_method=gradients_util.AggregationMethod. 
EXPERIMENTAL_ACCUMULATE_N) variables.global_variables_initializer().run() diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py index 17b69c7b35dce130c45ab0aadb28be330b4bfb88..c8524e9871864e0b4fffbd549d1fe347714f072a 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -84,7 +84,10 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): values = field_dict[field.name] self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) - fd = field.value.DESCRIPTOR.fields_by_name[field.name] + if 'ext_value' in field.name: + fd = test_example_pb2.PrimitiveValue() + else: + fd = field.value.DESCRIPTOR.fields_by_name[field.name] # Values has the same shape as the input plus an extra # dimension for repeats. @@ -92,13 +95,16 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # Nested messages are represented as TF strings, requiring # some special handling. 
- if field.name == 'message_value': + if field.name == 'message_value' or 'ext_value' in field.name: vs = [] for buf in values.flat: msg = test_example_pb2.PrimitiveValue() msg.ParseFromString(buf) vs.append(msg) - evs = getattr(field.value, field.name) + if 'ext_value' in field.name: + evs = field.value.Extensions[test_example_pb2.ext_value] + else: + evs = getattr(field.value, field.name) if len(vs) != len(evs): self.fail('Field %s decoded %d outputs, expected %d' % (fd.name, len(vs), len(evs))) @@ -223,7 +229,8 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): sanitize=False, force_disordered=True) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testPacked(self, case): # Now try with the packed serialization. # @@ -235,8 +242,7 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # Note: float_format='.17g' is necessary to ensure preservation of # doubles and floats in text format. text_format.Parse( - text_format.MessageToString( - value, float_format='.17g'), + text_format.MessageToString(value, float_format='.17g'), test_example_pb2.PackedTestValue()).SerializeToString() for value in case.values ] diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py index 01b3ccc7fd3918c4ff910281289e31177e5a8097..5ec681ff55dbd18580761bb23e7017cfc9767b89 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py @@ -15,9 +15,6 @@ # ============================================================================= """Table-driven test for encode_proto op. -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. 
- It tests that encode_proto is a lossless inverse of decode_proto (for the specified fields). """ @@ -145,7 +142,8 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # loss of packing in the encoding). self.assertEqual(in_buf, out_buf) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testRoundtrip(self, case): in_bufs = [value.SerializeToString() for value in case.values] @@ -154,7 +152,8 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): return self._testRoundtrip( in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testRoundtripPacked(self, case): # Now try with the packed serialization. # We test the packed representations by loading the same test cases using diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py index 2950c7dfdc59a11ba7d2c07d8406bd4af26b5bd9..1a636486a1765ad9544b5cb5e52961cc47f92950 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py @@ -38,17 +38,18 @@ class ProtoOpTestBase(test.TestCase): ct.cdll.LoadLibrary(lib) @staticmethod - def named_parameters(): - return ( - ("defaults", ProtoOpTestBase.defaults_test_case()), - ("minmax", ProtoOpTestBase.minmax_test_case()), - ("nested", ProtoOpTestBase.nested_test_case()), - ("optional", ProtoOpTestBase.optional_test_case()), - ("promote", ProtoOpTestBase.promote_test_case()), - ("ragged", ProtoOpTestBase.ragged_test_case()), - ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), - ("simple", 
ProtoOpTestBase.simple_test_case()), - ) + def named_parameters(extension=True): + parameters = [("defaults", ProtoOpTestBase.defaults_test_case()), + ("minmax", ProtoOpTestBase.minmax_test_case()), + ("nested", ProtoOpTestBase.nested_test_case()), + ("optional", ProtoOpTestBase.optional_test_case()), + ("promote", ProtoOpTestBase.promote_test_case()), + ("ragged", ProtoOpTestBase.ragged_test_case()), + ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), + ("simple", ProtoOpTestBase.simple_test_case())] + if extension: + parameters.append(("extension", ProtoOpTestBase.extension_test_case())) + return parameters @staticmethod def defaults_test_case(): @@ -399,6 +400,21 @@ class ProtoOpTestBase(test.TestCase): field.value.bool_value.append(True) return test_case + @staticmethod + def extension_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + message_value = value.Extensions[test_example_pb2.ext_value].add() + message_value.double_value = 23.5 + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = test_example_pb2.ext_value.full_name + field.dtype = types_pb2.DT_STRING + message_value = field.value.Extensions[test_example_pb2.ext_value].add() + message_value.double_value = 23.5 + return test_case + @staticmethod def simple_test_case(): test_case = test_example_pb2.TestCase() diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto index 674d881220a1113631def47c5111e3ef401b99f3..b1ce66de4feb9c6666ca9ccf39403b4e12840fcf 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto @@ -61,6 +61,8 @@ message TestValue { optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; optional sint32 sint32_value_with_default = 33 [default = 12]; optional sint64 sint64_value_with_default = 34 [default = 
13]; + + extensions 100 to 199; } // A PackedTestValue looks exactly the same as a TestValue in the text format, @@ -68,7 +70,7 @@ message TestValue { // by loading the same test cases using this definition instead of TestValue. // // NOTE: This definition must be kept in sync with TestValue in every way except -// the packed=true declaration. +// the packed=true declaration and the lack of extensions. message PackedTestValue { repeated double double_value = 1 [packed = true]; repeated float float_value = 2 [packed = true]; @@ -132,6 +134,10 @@ message ExtraFields { optional bool bool_value = 1777; } +extend TestValue { + repeated PrimitiveValue ext_value = 100; +} + // The messages below are for yet-to-be created tests. message EnumValue { diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b35c4fde1a2c704880e023a0c3ac1e0766493514..b67e68ea96a15f94e62050c92405eec4fe4be70f 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -202,8 +202,9 @@ py_test( py_test( name = "quantize_parameterized_test", - size = "large", + size = "medium", srcs = ["python/quantize_parameterized_test.py"], + shard_count = 4, srcs_version = "PY2AND3", # TODO(b/118839526): Re-enable msan test. tags = [ diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index 9085d9fa719520ac84ef6f8e07d7fa335bef5605..b335e1af69b7b2e6020f8e745c43bb1bdc95a62d 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -8,9 +8,9 @@ for both training and inference. There are two aspects to this: For efficient inference, TensorFlow combines batch normalization with the preceding convolutional and fully-connected layers prior to quantization by -[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}. 
+[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}. -The quantization error is modeled using [fake quantization](../api_guides/python/array_ops.md#Fake_quantization) +The quantization error is modeled using [fake quantization](../../api_guides/python/array_ops.md#Fake_quantization) nodes to simulate the effect of quantization in the forward and backward passes. The forward-pass models quantization, while the backward-pass models quantization as a straight-through estimator. Both the forward- and backward-pass simulate the quantization @@ -105,12 +105,12 @@ toco \ --std_value=127.5 --mean_value=127.5 ``` -See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/). +See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../../lite/). ## Quantized accuracy results -The following are results of trainiing some popular CNN models (Mobilenet-v1, +The following are results of training some popular CNN models (Mobilenet-v1, Mobilenet-v2, and Inception-v3) using this tool:

diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index e0c6da00d86fe4c5f881bcab7b444182da092b8f..a70f748fad60c6467946225ad5035caaf89c2aaf 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -454,7 +454,7 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor, strides=layer_op.get_attr('strides'), padding=layer_op.get_attr('padding'), use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), - data_format=layer_op.get_attr('data_format'), + data_format=layer_op.get_attr('data_format').decode(), name=new_layer_name) elif layer_op.type == 'MatMul': return math_ops.matmul( @@ -867,7 +867,7 @@ class _OpCloner(object): strides=op.get_attr('strides'), padding=op.get_attr('padding'), use_cudnn_on_gpu=op.get_attr('use_cudnn_on_gpu'), - data_format=op.get_attr('data_format'), + data_format=op.get_attr('data_format').decode(), name=new_name).op def _CloneDepthwiseConv2d(self, op, inputs, new_name): diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 8619708cdaecd78bcc7de0e8e0cbf2baa11bf6a2..39082cacf9770619cf5fb529ac9a0aad6e955c6d 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -224,8 +224,8 @@ def MovingAvgQuantize(inputs, None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope: scope.set_partitioner(None) input_shape = inputs.get_shape() - input_dim = len(input_shape) if per_channel: + input_dim = len(input_shape) # Only support quantizing 1-, 2- and 4-dimensional tensors. 
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' ' scope: %s' % (input_shape, name_prefix)) diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index 36d2af94e059cdc75b758bbf607d26c4e1ee73e9..c636c90d23a0f5a6de9d14085c824283cb41f6ca 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -63,6 +63,12 @@ class QuantOpsTest(googletest.TestCase): self.assertAlmostEqual(min_value, -0.5, delta=1e-3) self.assertAlmostEqual(max_value, 0.5, delta=1e-3) + def testMovingAvgQuantizeTrainingAssignNoShape(self): + min_value, max_value = self._GetMinMaxValues( + quant_ops.MovingAvgQuantize, [[-1, 1], [0, 0]], shape=None) + self.assertAlmostEqual(min_value, -0.5, delta=1e-3) + self.assertAlmostEqual(max_value, 0.5, delta=1e-3) + def testMovingAvgSymmetricQuantizeTrainingAssign(self): min_value, max_value = self._GetMinMaxValues( quant_ops.MovingAvgQuantize, [[-1, 0.5], [0, 0]], symmetric=True) @@ -109,10 +115,10 @@ class QuantOpsTest(googletest.TestCase): is_training=True, vars_collection=_MIN_MAX_VARS) - def _GetMinMaxValues(self, quantize_fn, input_values, **kwds): + def _GetMinMaxValues(self, quantize_fn, input_values, shape=(2), **kwds): g = ops.Graph() with session.Session(graph=g) as sess: - x = array_ops.placeholder(dtypes.float32, shape=[2]) + x = array_ops.placeholder(dtypes.float32, shape=shape) y = quantize_fn( x, init_min=0.0, diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 21d1b1213090273b5abd8e012f8711db98c94347..7c973fe597181b822e617db1f85a08f1b678e26f 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -685,7 +685,7 @@ def _InsertQuantOp(context, [1; 2^bits - 1] or wide range [0; 2^bits - 1]. producer_scope: The restriction of producer scope. 
If not None, the new op will be inserted only when the producer is in this scope. - consumer_scope: The restriction of producer scope. If not None, the new op + consumer_scope: The restriction of consumer scope. If not None, the new op will be inserted only when all the consumers are in this scope. Raises: ValueError: When producer operation is not directly connected to the diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 79b015a9163f5727caa40b54579c71e57621c92f..d1c41e4c0a11028765c9fc0dc345cb29453baa31 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -185,5 +185,4 @@ Effective padding (vertical) = 1482 ## Authors -André Araujo (github id: andrefaraujo) and Mark Sandler (github id: -marksandler) +André Araujo (@andrefaraujo) and Mark Sandler (@marksandler) diff --git a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py index d6fdd12bbe37fb0e0cb12f1d0adc3fce29b19e8a..72f98ccc32e945b48b5f1b570bcca323a5b5f48a 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Computes Receptive Field (RF) information given a graph protobuf. 
- -For an example of usage, see accompanying file compute_rf.sh -""" +"""Computes Receptive Field (RF) information given a graph protobuf.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py index a298b4d49038468299b58140758c69675368e855..325929a5937ac60a6134fae064e7633a4c57473d 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py @@ -16,8 +16,6 @@ The receptive field (and related parameters) for the different models are printed to stdout, and may also optionally be written to a CSV file. - -For an example of usage, see rf_benchmark.sh """ from __future__ import absolute_import @@ -262,11 +260,11 @@ def _model_rf(graphdef, information will be computed. model_type: Type of model to be used, used only for printing purposes. csv_writer: A CSV writer for RF parameters, which is used if it is not None. - input_resolution: Input resolution to use when computing RF - parameters. This is important for the case where padding can only be - defined if the input resolution is known, which may happen if using SAME - padding. This is assumed the resolution for both height and width. If - None, we consider the resolution is unknown. + input_resolution: Input resolution to use when computing RF parameters. This + is important for the case where padding can only be defined if the input + resolution is known, which may happen if using SAME padding. This is + assumed the resolution for both height and width. If None, we consider the + resolution is unknown. 
""" for desired_end_point_key in desired_end_point_keys: print('- %s:' % desired_end_point_key) @@ -283,10 +281,10 @@ def _model_rf(graphdef, if (receptive_field_x == receptive_field_y) and ( effective_stride_x == effective_stride_y) and ( effective_padding_x == effective_padding_y): - print('Receptive field size = %5s, effective stride = %5s, effective ' - 'padding = %5s' % (str(receptive_field_x), - str(effective_stride_x), - str(effective_padding_x))) + print( + 'Receptive field size = %5s, effective stride = %5s, effective ' + 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x), + str(effective_padding_x))) else: print('Receptive field size: horizontal = %5s, vertical = %5s. ' 'Effective stride: horizontal = %5s, vertical = %5s. Effective ' @@ -362,9 +360,8 @@ def _process_model_rf(model_type='resnet_v1_50', defined if the input resolution is known, which may happen if using SAME padding. The entries in the list are assumed the resolution for both height and width. If one of the elements in the list is None, we consider - it to mean that the resolution is unknown. If the list itself is None, - we use the default list [None, 224, 321]. - + it to mean that the resolution is unknown. If the list itself is None, we + use the default list [None, 224, 321]. """ # Process default value for this list. if input_resolutions is None: @@ -477,8 +474,8 @@ def _mobilenet_v1_rf(csv_writer=None): csv_writer: A CSV writer for RF parameters, which is used if it is not None. 
""" for model_type in _SUPPORTED_MOBILENETV1_VARIANTS: - with slim.arg_scope( - [slim.batch_norm, slim.dropout], is_training=False) as arg_sc: + with slim.arg_scope([slim.batch_norm, slim.dropout], + is_training=False) as arg_sc: _process_model_rf(model_type, csv_writer, arg_sc) diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index 0e3c46f17d2e2a277418d39e31927db73a509670..92ae1021bc8f8fbf19ca7f7cbe208ecea18128e8 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -27,7 +27,8 @@ from tensorflow.python.platform import tf_logging as logging _UNCHANGED_RF_LAYER_OPS = [ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu", - "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN" + "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN", + "GreaterEqual" ] # Different ways in which padding modes may be spelled. @@ -276,11 +277,11 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_node) # Compute the padding for this node separately for each direction. total_padding_x, padding_x = _padding_size_conv_pool( - node, kernel_size_x, stride_x, input_resolution[1] - if input_resolution is not None else None) + node, kernel_size_x, stride_x, + input_resolution[1] if input_resolution is not None else None) total_padding_y, padding_y = _padding_size_conv_pool( - node, kernel_size_y, stride_y, input_resolution[0] - if input_resolution is not None else None) + node, kernel_size_y, stride_y, + input_resolution[0] if input_resolution is not None else None) elif node.op == "Pad": # Kernel and stride are simply 1 in this case. 
kernel_size_x = 1 @@ -294,11 +295,11 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): kernel_size_x, kernel_size_y = _pool_kernel_size(node, name_to_node) # Compute the padding for this node separately for each direction. total_padding_x, padding_x = _padding_size_conv_pool( - node, kernel_size_x, stride_x, input_resolution[1] - if input_resolution is not None else None) + node, kernel_size_x, stride_x, + input_resolution[1] if input_resolution is not None else None) total_padding_y, padding_y = _padding_size_conv_pool( - node, kernel_size_y, stride_y, input_resolution[0] - if input_resolution is not None else None) + node, kernel_size_y, stride_y, + input_resolution[0] if input_resolution is not None else None) elif node.op in _UNCHANGED_RF_LAYER_OPS: # These nodes do not modify the RF parameters. kernel_size_x = 1 @@ -320,7 +321,7 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): total_padding_y = None padding_y = None else: - raise ValueError("Unknown layer for operation '%s': %s" % (node.name, - node.op)) + raise ValueError( + "Unknown layer for operation '%s': %s" % (node.name, node.op)) return (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y, total_padding_x, total_padding_y) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py index b9bd2f09761ab10a62d37e8e2580b93b9b8a4453..9127c772c75279d9c8eacc5a17680beba9247d01 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions to compute receptive field of a fully-convolutional network. 
- -Please refer to the following g3doc for detailed explanation on how this -computation is performed, and why it is important: -g3doc/photos/vision/features/delf/g3doc/rf_computation.md -""" +"""Functions to compute receptive field of a fully-convolutional network.""" from __future__ import absolute_import from __future__ import division @@ -96,8 +91,8 @@ class ReceptiveField(object): Args: y: An array of feature coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. - axis: The dimensions for which to compute the input center coordinates. - If `None` (the default), compute the input center coordinates for all + axis: The dimensions for which to compute the input center coordinates. If + `None` (the default), compute the input center coordinates for all dimensions. Returns: @@ -127,8 +122,8 @@ class ReceptiveField(object): Args: x: An array of input center coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. - axis: The dimensions for which to compute the feature coordinates. - If `None` (the default), compute the feature coordinates for all + axis: The dimensions for which to compute the feature coordinates. If + `None` (the default), compute the feature coordinates for all dimensions. Returns: @@ -274,14 +269,15 @@ def compute_receptive_field_from_graph_def(graph_def, continue # Get params for this layer. 
- (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, _, _) = parse_layer_parameters.get_layer_params( + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y, + _, _) = parse_layer_parameters.get_layer_params( node, name_to_node, node_info[node.name].input_size) - logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, " - "stride_x = %s, stride_y = %s, " - "padding_x = %s, padding_y = %s, input size = %s" % - (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, node_info[node.name].input_size)) + logging.vlog( + 3, "kernel_size_x = %s, kernel_size_y = %s, " + "stride_x = %s, stride_y = %s, " + "padding_x = %s, padding_y = %s, input size = %s" % + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, + padding_y, node_info[node.name].input_size)) if padding_x is None or padding_y is None: undefined_padding = True @@ -352,15 +348,15 @@ def compute_receptive_field_from_graph_def(graph_def, raise ValueError( "Graph is not aligned since effective stride from different " "paths is different in vertical direction") - if (rf_sizes_x[inp_name] - 1 - ) / 2 - effective_paddings_x[inp_name] != ( - rf_size_input_x - 1) / 2 - effective_padding_input_x: + if (rf_sizes_x[inp_name] - + 1) / 2 - effective_paddings_x[inp_name] != ( + rf_size_input_x - 1) / 2 - effective_padding_input_x: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in horizontal direction") - if (rf_sizes_y[inp_name] - 1 - ) / 2 - effective_paddings_y[inp_name] != ( - rf_size_input_y - 1) / 2 - effective_padding_input_y: + if (rf_sizes_y[inp_name] - + 1) / 2 - effective_paddings_y[inp_name] != ( + rf_size_input_y - 1) / 2 - effective_padding_input_y: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in vertical direction") diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py 
b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py index 2054367f0d1461c8868e3332d82322a8a3dd38af..7e79785d2867de586f0730373d4864602ef770ae 100644 --- a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py @@ -50,13 +50,13 @@ def remote_fused_graph_execute(inputs, if default_graph_input_tensor_type_shapes: for type_shape in default_graph_input_tensor_type_shapes: type_shape_proto = info_proto.default_graph_input_tensor_shape.add() - type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + type_shape_proto.dtype = dtypes.as_dtype(type_shape[0]).as_datatype_enum for dim in type_shape[1]: type_shape_proto.shape.dim.add().size = dim if default_graph_output_tensor_type_shapes: for type_shape in default_graph_output_tensor_type_shapes: type_shape_proto = info_proto.default_graph_output_tensor_shape.add() - type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + type_shape_proto.dtype = dtypes.as_dtype(type_shape[0]).as_datatype_enum for dim in type_shape[1]: type_shape_proto.shape.dim.add().size = dim diff --git a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py index d8ca0eab276b39f025d018edebb78eed7a8433bb..cec4c3c23305034d167a248a637425507750064e 100644 --- a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py +++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py @@ -164,6 +164,15 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + 
expected_grad_warp) + # One of (x, y) is less than 0. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -171,11 +180,21 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -1 is out of bound for grad_warp. warp_data = [-1, 0.1, 0.7, 0.6] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # Both of (x, y) are greater than image size. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -183,11 +202,20 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -0.1 is *inbound* for grad_warp and grad_data, 2.1 is out of bound. warp_data = [-0.1, 0.1, 1.2, 2.1] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.0]], [[0.09], [0.0]]]] + expected_grad_warp = [[[10.30, 2.7], [0.0, 0.0]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # One of (x, y) is greater than image size. 
for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -200,6 +228,14 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.81]], [[0.0], [0.08]]]] + expected_grad_warp = [[[-4.5, 9.5], [-9.9, 39.20]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index e124867415f94fb5052f34f50363ea718d71053b..24fa740d24502a28cb42c994715d09180ee99899 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -102,25 +102,6 @@ cuda_py_tests( xla_enabled = True, ) -cuda_py_tests( - name = "core_rnn_cell_test", - size = "medium", - srcs = ["python/kernel_tests/core_rnn_cell_test.py"], - additional_deps = [ - ":rnn_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:rnn", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - cuda_py_tests( name = "rnn_test", size = "medium", @@ -143,32 +124,6 @@ cuda_py_tests( ], ) -cuda_py_tests( - name = "core_rnn_test", - size = "medium", - srcs = ["python/kernel_tests/core_rnn_test.py"], - additional_deps = [ - ":rnn_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - 
"//tensorflow/python:rnn", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/eager:context", - ], - shard_count = 10, -) - tf_py_test( name = "fused_rnn_cell_test", size = "medium", @@ -226,7 +181,10 @@ tf_custom_op_library( "kernels/lstm_ops_gpu.cu.cc", "kernels/lstm_ops.h", ], - deps = ["//tensorflow/core/kernels:eigen_helpers"], + deps = [ + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", + ], ) tf_gen_op_wrapper_py( @@ -248,7 +206,10 @@ tf_custom_op_library( "kernels/gru_ops_gpu.cu.cc", "kernels/gru_ops.h", ], - deps = ["//tensorflow/core/kernels:eigen_helpers"], + deps = [ + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", + ], ) tf_gen_op_wrapper_py( @@ -345,6 +306,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", ], @@ -380,6 +342,13 @@ py_binary( name = "checkpoint_convert", srcs = ["python/tools/checkpoint_convert.py"], srcs_version = "PY2AND3", + deps = [":checkpoint_convert_lib"], +) + +py_library( + name = "checkpoint_convert_lib", + srcs = ["python/tools/checkpoint_convert.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_ops", @@ -398,7 +367,7 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":checkpoint_convert", + ":checkpoint_convert_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python:session", diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h index d37210d4b81203287fb633adc309688a35d093bb..12f3182a6a8878aa27ee143fa6405903e3fc4ef3 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.h +++ 
b/tensorflow/contrib/rnn/kernels/blas_gemm.h @@ -21,6 +21,10 @@ limitations under the License. #include "tensorflow/core/kernels/eigen_activations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + namespace tensorflow { class OpKernelContext; namespace functor { diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py deleted file mode 100644 index 7d57b0413a3bb51c35e670ce3fdb2cc818f44a58..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ /dev/null @@ -1,1078 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for RNN cells.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -import numpy as np - -from tensorflow.contrib import rnn as contrib_rnn -from tensorflow.contrib.rnn.python.ops import core_rnn_cell -from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import rnn -from tensorflow.python.ops import rnn_cell_impl -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import util as checkpointable_utils - -# pylint: enable=protected-access -Linear = core_rnn_cell._Linear # pylint: disable=invalid-name - - -class RNNCellTest(test.TestCase): - - def testLinear(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(1.0)): - x = array_ops.zeros([1, 2]) - l = Linear([x], 2, False)([x]) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([l], {x.name: np.array([[1., 2.]])}) - self.assertAllClose(res[0], [[3.0, 3.0]]) - - # Checks prevent you from accidentally creating a shared function. - with self.assertRaises(ValueError): - l1 = Linear([x], 2, False)([x]) - - # But you can create a new one in a new scope and share the variables. 
- with variable_scope.variable_scope("l1") as new_scope: - l1 = Linear([x], 2, False)([x]) - with variable_scope.variable_scope(new_scope, reuse=True): - Linear([l1], 2, False)([l1]) - self.assertEqual(len(variables_lib.trainable_variables()), 2) - - def testBasicRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = rnn_cell_impl.BasicRNNCell(2) - g, _ = cell(x, m) - self.assertEqual([ - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testBasicRNNCellNotTrainable(self): - with self.cached_session() as sess: - - def not_trainable_getter(getter, *args, **kwargs): - kwargs["trainable"] = False - return getter(*args, **kwargs) - - with variable_scope.variable_scope( - "root", - initializer=init_ops.constant_initializer(0.5), - custom_getter=not_trainable_getter): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = rnn_cell_impl.BasicRNNCell(2) - g, _ = cell(x, m) - self.assertFalse(cell.trainable_variables) - self.assertEqual([ - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.non_trainable_variables]) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testIndRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - 
"root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = contrib_rnn_cell.IndRNNCell(2) - g, _ = cell(x, m) - self.assertEqual([ - "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testGRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = rnn_cell_impl.GRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.175991, 0.175991]]) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test GRUCell with input_size != num_units. 
- x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 2]) - g, _ = rnn_cell_impl.GRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.156736, 0.156736]]) - - def testIndyGRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.185265, 0.17704]]) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test IndyGRUCell with input_size != num_units. - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.155127, 0.157328]]) - - def testSRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.SRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.509682, 0.509682]]) - - def testSRUCellWithDiffSize(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = 
array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.SRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.55255556, 0.55255556]]) - - def testBasicLSTMCell(self): - for dtype in [dtypes.float16, dtypes.float32]: - np_dtype = dtype.as_numpy_dtype - with self.session(graph=ops.Graph()) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2], dtype=dtype) - m = array_ops.zeros([1, 8], dtype=dtype) - cell = rnn_cell_impl.MultiRNNCell( - [ - rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) - for _ in range(2) - ], - state_is_tuple=False) - self.assertEqual(cell.dtype, None) - self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) - self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) - cell.get_config() # Should not throw an error - g, out_m = cell(x, m) - # Layer infers the input type. 
- self.assertEqual(cell.dtype, dtype.name) - expected_variable_names = [ - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME - ] - self.assertEqual(expected_variable_names, - [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, out_m], { - x.name: np.array([[1., 1.]]), - m.name: 0.1 * np.ones([1, 8]) - }) - self.assertEqual(len(res), 2) - variables = variables_lib.global_variables() - self.assertEqual(expected_variable_names, [v.name for v in variables]) - # The numbers in results were not calculated, this is just a - # smoke test. - self.assertAllClose(res[0], np.array( - [[0.240, 0.240]], dtype=np_dtype), 1e-2) - expected_mem = np.array( - [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], - dtype=np_dtype) - self.assertAllClose(res[1], expected_mem, 1e-2) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test BasicLSTMCell with input_size != num_units. 
- x = array_ops.zeros([1, 3], dtype=dtype) - m = array_ops.zeros([1, 4], dtype=dtype) - g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], { - x.name: np.array([[1., 1., 1.]], dtype=np_dtype), - m.name: 0.1 * np.ones([1, 4], dtype=np_dtype) - }) - self.assertEqual(len(res), 2) - - def testBasicLSTMCellDimension0Error(self): - """Tests that dimension 0 in both(x and m) shape must be equal.""" - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - num_units = 2 - state_size = num_units * 2 - batch_size = 3 - input_size = 4 - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size - 1, state_size]) - with self.assertRaises(ValueError): - g, out_m = rnn_cell_impl.BasicLSTMCell( - num_units, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - sess.run( - [g, out_m], { - x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size - 1, state_size]) - }) - - def testBasicLSTMCellStateSizeError(self): - """Tests that state_size must be num_units * 2.""" - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - num_units = 2 - state_size = num_units * 3 # state_size must be num_units * 2 - batch_size = 3 - input_size = 4 - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - with self.assertRaises(ValueError): - g, out_m = rnn_cell_impl.BasicLSTMCell( - num_units, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - sess.run( - [g, out_m], { - x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size, state_size]) - }) - - def testBasicLSTMCellStateTupleType(self): - with self.cached_session(): - with 
variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m0 = (array_ops.zeros([1, 2]),) * 2 - m1 = (array_ops.zeros([1, 2]),) * 2 - cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)], - state_is_tuple=True) - self.assertTrue(isinstance(cell.state_size, tuple)) - self.assertTrue( - isinstance(cell.state_size[0], rnn_cell_impl.LSTMStateTuple)) - self.assertTrue( - isinstance(cell.state_size[1], rnn_cell_impl.LSTMStateTuple)) - - # Pass in regular tuples - _, (out_m0, out_m1) = cell(x, (m0, m1)) - self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) - - # Pass in LSTMStateTuples - variable_scope.get_variable_scope().reuse_variables() - zero_state = cell.zero_state(1, dtypes.float32) - self.assertTrue(isinstance(zero_state, tuple)) - self.assertTrue(isinstance(zero_state[0], rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(zero_state[1], rnn_cell_impl.LSTMStateTuple)) - _, (out_m0, out_m1) = cell(x, zero_state) - self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) - - def testBasicLSTMCellWithStateTuple(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m0 = array_ops.zeros([1, 4]) - m1 = array_ops.zeros([1, 4]) - cell = rnn_cell_impl.MultiRNNCell( - [ - rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) - for _ in range(2) - ], - state_is_tuple=True) - g, (out_m0, out_m1) = cell(x, (m0, m1)) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m0, out_m1], { - x.name: np.array([[1., 1.]]), - m0.name: 0.1 * np.ones([1, 4]), - m1.name: 0.1 * np.ones([1, 4]) - }) - self.assertEqual(len(res), 3) - # The numbers in results were not calculated, 
this is just a smoke test. - # Note, however, these values should match the original - # version having state_is_tuple=False. - self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) - expected_mem0 = np.array( - [[0.68967271, 0.68967271, 0.44848421, 0.44848421]]) - expected_mem1 = np.array( - [[0.39897051, 0.39897051, 0.24024698, 0.24024698]]) - self.assertAllClose(res[1], expected_mem0) - self.assertAllClose(res[2], expected_mem1) - - def testIndyLSTMCell(self): - for dtype in [dtypes.float16, dtypes.float32]: - np_dtype = dtype.as_numpy_dtype - with self.session(graph=ops.Graph()) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2], dtype=dtype) - state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - cell = rnn_cell_impl.MultiRNNCell( - [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)]) - self.assertEqual(cell.dtype, None) - self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) - self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) - cell.get_config() # Should not throw an error - g, (out_state_0, out_state_1) = cell(x, (state_0, state_1)) - # Layer infers the input type. 
- self.assertEqual(cell.dtype, dtype.name) - expected_variable_names = [ - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME - ] - self.assertEqual(expected_variable_names, - [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_state_0, out_state_1], { - x.name: np.array([[1., 1.]]), - state_0[0].name: 0.1 * np.ones([1, 2]), - state_0[1].name: 0.1 * np.ones([1, 2]), - state_1[0].name: 0.1 * np.ones([1, 2]), - state_1[1].name: 0.1 * np.ones([1, 2]), - }) - self.assertEqual(len(res), 3) - variables = variables_lib.global_variables() - self.assertEqual(expected_variable_names, [v.name for v in variables]) - # Only check the range of outputs as this is just a smoke test. - self.assertAllInRange(res[0], -1.0, 1.0) - self.assertAllInRange(res[1], -1.0, 1.0) - self.assertAllInRange(res[2], -1.0, 1.0) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test IndyLSTMCell with input_size != num_units. 
- x = array_ops.zeros([1, 3], dtype=dtype) - state = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_state], { - x.name: np.array([[1., 1., 1.]], dtype=np_dtype), - state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype), - state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype), - }) - self.assertEqual(len(res), 2) - - def testLSTMCell(self): - with self.cached_session() as sess: - num_units = 8 - num_proj = 6 - state_size = num_units + num_proj - batch_size = 3 - input_size = 2 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - cell = rnn_cell_impl.LSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - state_is_tuple=False) - output, state = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [output, state], { - x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]), - m.name: 0.1 * np.ones((batch_size, state_size)) - }) - self.assertEqual(len(res), 2) - # The numbers in results were not calculated, this is mostly just a - # smoke test. 
- self.assertEqual(res[0].shape, (batch_size, num_proj)) - self.assertEqual(res[1].shape, (batch_size, state_size)) - # Different inputs so different outputs and states - for i in range(1, batch_size): - self.assertTrue( - float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6) - self.assertTrue( - float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) - - def testLSTMCellVariables(self): - with self.cached_session(): - num_units = 8 - num_proj = 6 - state_size = num_units + num_proj - batch_size = 3 - input_size = 2 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - cell = rnn_cell_impl.LSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - state_is_tuple=False) - cell(x, m) # Execute to create variables - variables = variables_lib.global_variables() - self.assertEquals(variables[0].op.name, "root/lstm_cell/kernel") - self.assertEquals(variables[1].op.name, "root/lstm_cell/bias") - self.assertEquals(variables[2].op.name, - "root/lstm_cell/projection/kernel") - - def testLSTMCellLayerNorm(self): - with self.cached_session() as sess: - num_units = 2 - num_proj = 3 - batch_size = 1 - input_size = 4 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - c = array_ops.zeros([batch_size, num_units]) - h = array_ops.zeros([batch_size, num_proj]) - state = rnn_cell_impl.LSTMStateTuple(c, h) - cell = contrib_rnn_cell.LayerNormLSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - layer_norm=True, - norm_gain=1.0, - norm_shift=0.0) - g, out_m = cell(x, state) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], { - x.name: np.ones((batch_size, input_size)), - c.name: 0.1 * np.ones((batch_size, num_units)), - h.name: 0.1 * np.ones((batch_size, num_proj)) - 
}) - self.assertEqual(len(res), 2) - # The numbers in results were not calculated, this is mostly just a - # smoke test. - self.assertEqual(res[0].shape, (batch_size, num_proj)) - self.assertEqual(res[1][0].shape, (batch_size, num_units)) - self.assertEqual(res[1][1].shape, (batch_size, num_proj)) - # Different inputs so different outputs and states - for i in range(1, batch_size): - self.assertTrue( - float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6) - self.assertTrue( - float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testWrapperCheckpointing(self): - for wrapper_type in [ - rnn_cell_impl.DropoutWrapper, - rnn_cell_impl.ResidualWrapper, - lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: - cell = rnn_cell_impl.BasicRNNCell(1) - wrapper = wrapper_type(cell) - wrapper(array_ops.ones([1, 1]), - state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) - self.evaluate([v.initializer for v in cell.variables]) - checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(cell._bias.assign([40.])) - save_path = checkpoint.save(prefix) - self.evaluate(cell._bias.assign([0.])) - checkpoint.restore(save_path).assert_consumed().run_restore_ops() - self.assertAllEqual([40.], self.evaluate(cell._bias)) - - def testOutputProjectionWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - cell = contrib_rnn.OutputProjectionWrapper(rnn_cell_impl.GRUCell(3), 2) - g, new_m = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 3)) - # The numbers in results were not calculated, this is just a smoke test. 
- self.assertAllClose(res[0], [[0.231907, 0.231907]]) - - def testInputProjectionWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 3]) - cell = contrib_rnn.InputProjectionWrapper( - rnn_cell_impl.GRUCell(3), num_proj=3) - g, new_m = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 3)) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) - - def testResidualWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - base_cell = rnn_cell_impl.GRUCell(3) - g, m_new = base_cell(x, m) - variable_scope.get_variable_scope().reuse_variables() - wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell) - (name, dep), = wrapper_object._checkpoint_dependencies - wrapper_object.get_config() # Should not throw an error - self.assertIs(dep, base_cell) - self.assertEqual("cell", name) - - g_res, m_new_res = wrapper_object(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, g_res, m_new, m_new_res], { - x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.1, 0.1]]) - }) - # Residual connections - self.assertAllClose(res[1], res[0] + [1., 1., 1.]) - # States are left untouched - self.assertAllClose(res[2], res[3]) - - def testResidualWrapperWithSlice(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 5]) - m = array_ops.zeros([1, 3]) - base_cell = rnn_cell_impl.GRUCell(3) - g, m_new = 
base_cell(x, m) - variable_scope.get_variable_scope().reuse_variables() - - def residual_with_slice_fn(inp, out): - inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) - return inp_sliced + out - - g_res, m_new_res = rnn_cell_impl.ResidualWrapper( - base_cell, residual_with_slice_fn)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res_g, res_g_res, res_m_new, res_m_new_res = sess.run( - [g, g_res, m_new, m_new_res], { - x: np.array([[1., 1., 1., 1., 1.]]), - m: np.array([[0.1, 0.1, 0.1]]) - }) - # Residual connections - self.assertAllClose(res_g_res, res_g + [1., 1., 1.]) - # States are left untouched - self.assertAllClose(res_m_new, res_m_new_res) - - def testDeviceWrapper(self): - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - wrapped = rnn_cell_impl.GRUCell(3) - cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") - (name, dep), = cell._checkpoint_dependencies - cell.get_config() # Should not throw an error - self.assertIs(dep, wrapped) - self.assertEqual("cell", name) - - outputs, _ = cell(x, m) - self.assertTrue("cpu:14159" in outputs.device.lower()) - - def _retrieve_cpu_gpu_stats(self, run_metadata): - cpu_stats = None - gpu_stats = None - step_stats = run_metadata.step_stats - for ds in step_stats.dev_stats: - if "cpu:0" in ds.device[-5:].lower(): - cpu_stats = ds.node_stats - if "gpu:0" == ds.device[-5:].lower(): - gpu_stats = ds.node_stats - return cpu_stats, gpu_stats - - def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self): - if not test.is_gpu_available(): - # Can't perform this test w/o a GPU - return - - gpu_dev = test.gpu_device_name() - with self.session(use_gpu=True) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 1, 3]) - cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), gpu_dev) - with 
ops.device("/cpu:0"): - outputs, _ = rnn.dynamic_rnn( - cell=cell, inputs=x, dtype=dtypes.float32) - run_metadata = config_pb2.RunMetadata() - opts = config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE) - - sess.run([variables_lib.global_variables_initializer()]) - _ = sess.run(outputs, options=opts, run_metadata=run_metadata) - - cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) - self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) - self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) - - def testEmbeddingWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 1], dtype=dtypes.int32) - m = array_ops.zeros([1, 2]) - embedding_cell = contrib_rnn.EmbeddingWrapper( - rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2) - self.assertEqual(embedding_cell.output_size, 2) - g, new_m = embedding_cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 2)) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.17139, 0.17139]]) - - def testEmbeddingWrapperWithDynamicRnn(self): - with self.cached_session() as sess: - with variable_scope.variable_scope("root"): - inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) - input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) - embedding_cell = contrib_rnn.EmbeddingWrapper( - rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True), - embedding_classes=1, - embedding_size=2) - outputs, _ = rnn.dynamic_rnn( - cell=embedding_cell, - inputs=inputs, - sequence_length=input_lengths, - dtype=dtypes.float32) - sess.run([variables_lib.global_variables_initializer()]) - # This will fail if output's dtype is inferred from input's. 
- sess.run(outputs) - - def testMultiRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 4]) - multi_rnn_cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=False) - _, ml = multi_rnn_cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run(ml, { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1, 0.1]]) - }) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) - self.assertEqual(len(multi_rnn_cell.weights), 2 * 4) - self.assertTrue( - [x.dtype == dtypes.float32 for x in multi_rnn_cell.weights]) - - def testMultiRNNCellWithStateTuple(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m_bad = array_ops.zeros([1, 4]) - m_good = (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])) - - # Test incorrectness of state - with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): - rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=True)(x, m_bad) - - _, ml = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=True)(x, m_good) - - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - ml, { - x.name: np.array([[1., 1.]]), - m_good[0].name: np.array([[0.1, 0.1]]), - m_good[1].name: np.array([[0.1, 0.1]]) - }) - - # The numbers in results were not calculated, this is just a - # smoke test. However, these numbers should match those of - # the test testMultiRNNCell. 
- self.assertAllClose(res[0], [[0.175991, 0.175991]]) - self.assertAllClose(res[1], [[0.13248, 0.13248]]) - - -class DropoutWrapperTest(test.TestCase): - - def _testDropoutWrapper(self, - batch_size=None, - time_steps=None, - parallel_iterations=None, - **kwargs): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - if batch_size is None and time_steps is None: - # 2 time steps, batch size 1, depth 3 - batch_size = 1 - time_steps = 2 - x = constant_op.constant( - [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32) - m = rnn_cell_impl.LSTMStateTuple( - *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32 - )] * 2) - else: - x = constant_op.constant( - np.random.randn(time_steps, batch_size, 3).astype(np.float32)) - m = rnn_cell_impl.LSTMStateTuple(*[ - constant_op. - constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32) - ] * 2) - outputs, final_state = rnn.dynamic_rnn( - cell=rnn_cell_impl.DropoutWrapper( - rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs), - time_major=True, - parallel_iterations=parallel_iterations, - inputs=x, - initial_state=m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([outputs, final_state]) - self.assertEqual(res[0].shape, (time_steps, batch_size, 3)) - self.assertEqual(res[1].c.shape, (batch_size, 3)) - self.assertEqual(res[1].h.shape, (batch_size, 3)) - return res - - def testWrappedCellProperty(self): - cell = rnn_cell_impl.BasicRNNCell(10) - wrapper = rnn_cell_impl.DropoutWrapper(cell) - # Github issue 15810 - self.assertEqual(wrapper.wrapped_cell, cell) - - def testDropoutWrapperKeepAllConstantInput(self): - keep = array_ops.ones([]) - res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - 
[[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(true_full_output, res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - def testDropoutWrapperKeepAll(self): - keep = variable_scope.get_variable("all", initializer=1.0) - res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(true_full_output, res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - def testDropoutWrapperWithSeed(self): - keep_some = 0.5 - random_seed.set_random_seed(2) - ## Use parallel_iterations = 1 in both calls to - ## _testDropoutWrapper to ensure the (per-time step) dropout is - ## consistent across both calls. Otherwise the seed may not end - ## up being munged consistently across both graphs. 
- res_standard_1 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - seed=10, - parallel_iterations=1) - # Clear away the graph and the test session (which keeps variables around) - ops.reset_default_graph() - self._ClearCachedSession() - random_seed.set_random_seed(2) - res_standard_2 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - seed=10, - parallel_iterations=1) - self.assertAllClose(res_standard_1[0], res_standard_2[0]) - self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c) - self.assertAllClose(res_standard_1[1].h, res_standard_2[1].h) - - def testDropoutWrapperKeepNoOutput(self): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - res = self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_none, - state_keep_prob=keep_all) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(np.zeros(res[0].shape), res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - # Even though we dropout state, by default DropoutWrapper never - # drops out the memory ("c") term of an LSTMStateTuple. 
- res = self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_all, - state_keep_prob=keep_none) - true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - self.assertAllClose(true_full_output[0], res[0][0]) - # Second output is modified by zero input state - self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4) - # h state has been set to zero - self.assertAllClose(np.zeros(res[1].h.shape), res[1].h) - # c state of an LSTMStateTuple is NEVER modified. - self.assertAllClose(true_c_state, res[1].c) - - def testDropoutWrapperKeepNoInput(self): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - # All outputs are different because inputs are zeroed out - res = self._testDropoutWrapper( - input_keep_prob=keep_none, - output_keep_prob=keep_all, - state_keep_prob=keep_all) - self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4) - self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4) - self.assertGreater(np.linalg.norm(res[1].c - true_full_final_c), 1e-4) - - def testDropoutWrapperRecurrentOutput(self): - keep_some = 0.8 - keep_all = variable_scope.get_variable("all", initializer=1.0) - res = self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_some, - state_keep_prob=keep_all, - variational_recurrent=True, - input_size=3, - batch_size=5, - time_steps=7) - # Ensure the same dropout pattern for all time steps - output_mask = np.abs(res[0]) > 1e-6 - for m in output_mask[1:]: - self.assertAllClose(output_mask[0], m) - - def 
testDropoutWrapperRecurrentStateInputAndOutput(self): - keep_some = 0.9 - res = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - input_size=3, - batch_size=5, - time_steps=7) - - # Smoke test for the state/input masks. - output_mask = np.abs(res[0]) > 1e-6 - for time_step in output_mask: - # Ensure the same dropout output pattern for all time steps - self.assertAllClose(output_mask[0], time_step) - for batch_entry in time_step: - # Assert all batch entries get the same mask - self.assertAllClose(batch_entry, time_step[0]) - - # For state, ensure all batch entries have the same mask - state_c_mask = np.abs(res[1].c) > 1e-6 - state_h_mask = np.abs(res[1].h) > 1e-6 - for batch_entry in state_c_mask: - self.assertAllClose(batch_entry, state_c_mask[0]) - for batch_entry in state_h_mask: - self.assertAllClose(batch_entry, state_h_mask[0]) - - def testDropoutWrapperRecurrentStateInputAndOutputWithSeed(self): - keep_some = 0.9 - random_seed.set_random_seed(2347) - np.random.seed(23487) - res0 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - input_size=3, - batch_size=5, - time_steps=7, - seed=-234987) - ops.reset_default_graph() - self._ClearCachedSession() - random_seed.set_random_seed(2347) - np.random.seed(23487) - res1 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - input_size=3, - batch_size=5, - time_steps=7, - seed=-234987) - - output_mask = np.abs(res0[0]) > 1e-6 - for time_step in output_mask: - # Ensure the same dropout output pattern for all time steps - self.assertAllClose(output_mask[0], time_step) - for batch_entry in time_step: - # Assert all batch entries get the same mask - self.assertAllClose(batch_entry, time_step[0]) - - # For state, ensure all batch entries have 
the same mask - state_c_mask = np.abs(res0[1].c) > 1e-6 - state_h_mask = np.abs(res0[1].h) > 1e-6 - for batch_entry in state_c_mask: - self.assertAllClose(batch_entry, state_c_mask[0]) - for batch_entry in state_h_mask: - self.assertAllClose(batch_entry, state_h_mask[0]) - - # Ensure seeded calculation is identical. - self.assertAllClose(res0[0], res1[0]) - self.assertAllClose(res0[1].c, res1[1].c) - self.assertAllClose(res0[1].h, res1[1].h) - - -def basic_rnn_cell(inputs, state, num_units, scope=None): - if state is None: - if inputs is not None: - batch_size = inputs.get_shape()[0] - dtype = inputs.dtype - else: - batch_size = 0 - dtype = dtypes.float32 - init_output = array_ops.zeros( - array_ops.stack([batch_size, num_units]), dtype=dtype) - init_state = array_ops.zeros( - array_ops.stack([batch_size, num_units]), dtype=dtype) - init_output.set_shape([batch_size, num_units]) - init_state.set_shape([batch_size, num_units]) - return init_output, init_state - else: - with variable_scope.variable_scope(scope, "basic_rnn_cell", - [inputs, state]): - output = math_ops.tanh( - Linear([inputs, state], num_units, True)([inputs, state])) - return output, output - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index aa1d7d2b01b4595bbb03ba8e867e93db759cbd52..dfac2df6a0d4143106ad0f090805597c26659280 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -22,6 +22,7 @@ import itertools import numpy as np +from tensorflow.contrib.rnn.python.ops import core_rnn_cell as legacy_rnn_cell from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -29,7 +30,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import 
dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.keras import initializers +from tensorflow.python.keras import layers as keras_layers from tensorflow.python.keras import testing_utils from tensorflow.python.keras import utils from tensorflow.python.ops import array_ops @@ -51,6 +54,294 @@ from tensorflow.python.util import nest class RNNCellTest(test.TestCase): + def testIndRNNCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + cell = contrib_rnn_cell.IndRNNCell(2) + g, _ = cell(x, m) + self.assertEqual([ + "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME + ], [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + self.assertEqual(res[0].shape, (1, 2)) + + def testIndyGRUCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.185265, 0.17704]]) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyGRUCell with input_size != num_units. 
+ x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.155127, 0.157328]]) + + def testIndyLSTMCell(self): + for dtype in [dtypes.float16, dtypes.float32]: + np_dtype = dtype.as_numpy_dtype + with self.session(graph=ops.Graph()) as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2], dtype=dtype) + state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + cell = rnn_cell_impl.MultiRNNCell( + [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)]) + self.assertEqual(cell.dtype, None) + self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) + self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error + g, (out_state_0, out_state_1) = cell(x, (state_0, state_1)) + # Layer infers the input type. 
+ self.assertEqual(cell.dtype, dtype.name) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME + ] + self.assertEqual(expected_variable_names, + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_state_0, out_state_1], { + x.name: np.array([[1., 1.]]), + state_0[0].name: 0.1 * np.ones([1, 2]), + state_0[1].name: 0.1 * np.ones([1, 2]), + state_1[0].name: 0.1 * np.ones([1, 2]), + state_1[1].name: 0.1 * np.ones([1, 2]), + }) + self.assertEqual(len(res), 3) + global_variables = variables.global_variables() + self.assertEqual(expected_variable_names, + [v.name for v in global_variables]) + # Only check the range of outputs as this is just a smoke test. + self.assertAllInRange(res[0], -1.0, 1.0) + self.assertAllInRange(res[1], -1.0, 1.0) + self.assertAllInRange(res[2], -1.0, 1.0) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyLSTMCell with input_size != num_units. 
+ x = array_ops.zeros([1, 3], dtype=dtype) + state = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_state], { + x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + }) + self.assertEqual(len(res), 2) + + def testLSTMCellLayerNorm(self): + with self.cached_session() as sess: + num_units = 2 + num_proj = 3 + batch_size = 1 + input_size = 4 + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([batch_size, input_size]) + c = array_ops.zeros([batch_size, num_units]) + h = array_ops.zeros([batch_size, num_proj]) + state = rnn_cell_impl.LSTMStateTuple(c, h) + cell = contrib_rnn_cell.LayerNormLSTMCell( + num_units=num_units, + num_proj=num_proj, + forget_bias=1.0, + layer_norm=True, + norm_gain=1.0, + norm_shift=0.0) + g, out_m = cell(x, state) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_m], { + x.name: np.ones((batch_size, input_size)), + c.name: 0.1 * np.ones((batch_size, num_units)), + h.name: 0.1 * np.ones((batch_size, num_proj)) + }) + self.assertEqual(len(res), 2) + # The numbers in results were not calculated, this is mostly just a + # smoke test. 
+ self.assertEqual(res[0].shape, (batch_size, num_proj)) + self.assertEqual(res[1][0].shape, (batch_size, num_units)) + self.assertEqual(res[1][1].shape, (batch_size, num_proj)) + # Different inputs so different outputs and states + for i in range(1, batch_size): + self.assertTrue( + float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6) + self.assertTrue( + float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) + + def testOutputProjectionWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 3]) + cell = legacy_rnn_cell.OutputProjectionWrapper( + rnn_cell_impl.GRUCell(3), 2) + g, new_m = cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.231907, 0.231907]]) + + def testInputProjectionWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 3]) + cell = legacy_rnn_cell.InputProjectionWrapper( + rnn_cell_impl.GRUCell(3), num_proj=3) + g, new_m = cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. 
+ self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) + + def testEmbeddingWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 1], dtype=dtypes.int32) + m = array_ops.zeros([1, 2]) + embedding_cell = legacy_rnn_cell.EmbeddingWrapper( + rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2) + self.assertEqual(embedding_cell.output_size, 2) + g, new_m = embedding_cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1]]), + m.name: np.array([[0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 2)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.17139, 0.17139]]) + + def testEmbeddingWrapperWithDynamicRnn(self): + with self.cached_session() as sess: + with variable_scope.variable_scope("root"): + inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) + input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) + embedding_cell = legacy_rnn_cell.EmbeddingWrapper( + rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True), + embedding_classes=1, + embedding_size=2) + outputs, _ = rnn.dynamic_rnn( + cell=embedding_cell, + inputs=inputs, + sequence_length=input_lengths, + dtype=dtypes.float32) + sess.run([variables.global_variables_initializer()]) + # This will fail if output's dtype is inferred from input's. 
+ sess.run(outputs) + + def testSRUCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.SRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.509682, 0.509682]]) + + def testSRUCellKerasRNN(self): + """Tests that SRUCell works with keras RNN layer.""" + cell = contrib_rnn_cell.SRUCell(10) + seq_input = ops.convert_to_tensor( + np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) + rnn_layer = keras_layers.RNN(cell=cell) + rnn_outputs_keras = rnn_layer(seq_input) + with self.cached_session() as sess: + sess.run([variables.global_variables_initializer()]) + self.assertEqual(sess.run(rnn_outputs_keras).shape, (2, 10)) + + def testSRUCellBiasType(self): + """Tests that the bias' dtype is properly set.""" + cell = contrib_rnn_cell.SRUCell(10) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.float32_ref) + + cell = contrib_rnn_cell.SRUCell(10, dtype=dtypes.int32) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.int32_ref) + + cell_input = ops.convert_to_tensor( + np.random.rand(2, 5), name="cell_input", dtype=dtypes.float16) + cell_state = ops.convert_to_tensor( + np.random.rand(2, 10), name="cell_state", dtype=dtypes.float16) + cell = contrib_rnn_cell.SRUCell(10) + cell(cell_input, [cell_state]) + self.assertEqual(cell._bias.dtype, dtypes.float16_ref) + + def testSRUCellWithDiffSize(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.SRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = 
sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.55255556, 0.55255556]]) + def testCoupledInputForgetGateLSTMCell(self): with self.cached_session() as sess: num_units = 2 @@ -763,6 +1054,17 @@ class RNNCellTest(test.TestCase): self.assertEqual(new_h.shape[1], num_proj) self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) + @test_util.run_in_graph_and_eager_modes + def testNASCellKerasRNN(self): + """Tests that NASCell works with keras RNN layer.""" + cell = contrib_rnn_cell.NASCell(10) + seq_input = ops.convert_to_tensor( + np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) + rnn_layer = keras_layers.RNN(cell=cell) + rnn_outputs = rnn_layer(seq_input) + self.evaluate([variables.global_variables_initializer()]) + self.assertEqual(self.evaluate(rnn_outputs).shape, (2, 10)) + def testUGRNNCell(self): num_units = 2 batch_size = 3 diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py index 0266b72dcb15e4aba01a9a31b4be75c5b84d44da..41b1698321e20f4360d75fa2db79f9bd8a806cea 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn.py +++ b/tensorflow/contrib/rnn/python/ops/rnn.py @@ -131,7 +131,8 @@ def stack_bidirectional_dynamic_rnn(cells_fw, sequence_length=None, parallel_iterations=None, time_major=False, - scope=None): + scope=None, + swap_memory=False): """Creates a dynamic bidirectional recurrent neural network. Stacks several bidirectional rnn layers. The combined forward and backward @@ -171,6 +172,10 @@ def stack_bidirectional_dynamic_rnn(cells_fw, data is batch-major, so by default this function accepts input and emits output in batch-major form. scope: VariableScope for the created subgraph; defaults to None. + swap_memory: Transparently swap the tensors produced in forward inference + but needed for back prop from GPU to CPU. 
This allows training RNNs + which would typically not fit on a single GPU, with very minimal (or no) + performance penalty. Returns: A tuple (outputs, output_state_fw, output_state_bw) where: @@ -230,6 +235,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw, sequence_length=sequence_length, parallel_iterations=parallel_iterations, dtype=dtype, + swap_memory=swap_memory, time_major=time_major) # Concat the outputs to create the new input. prev_layer = array_ops.concat(outputs, 2) diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 8a1c09f171e6108174671e3122d5ff4c0b236003..d25afc8b9c4381fb3b0092ef21f46646353e1b8e 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1462,7 +1462,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): return new_h, new_state -class NASCell(rnn_cell_impl.RNNCell): +class NASCell(rnn_cell_impl.LayerRNNCell): """Neural Architecture Search (NAS) recurrent network cell. This implements the recurrent cell from the paper: @@ -1475,23 +1475,28 @@ class NASCell(rnn_cell_impl.RNNCell): The class uses an optional projection layer. """ - def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None): + # NAS cell's architecture base. + _NAS_BASE = 8 + + def __init__(self, num_units, num_proj=None, use_bias=False, reuse=None, + **kwargs): """Initialize the parameters for a NAS cell. Args: - num_units: int, The number of units in the NAS cell + num_units: int, The number of units in the NAS cell. num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. - use_biases: (optional) bool, If True then use biases within the cell. This + use_bias: (optional) bool, If True then use biases within the cell. This is False by default. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. 
If not `True`, and the existing scope already has the given variables, an error is raised. + **kwargs: Additional keyword arguments. """ - super(NASCell, self).__init__(_reuse=reuse) + super(NASCell, self).__init__(_reuse=reuse, **kwargs) self._num_units = num_units self._num_proj = num_proj - self._use_biases = use_biases + self._use_bias = use_bias self._reuse = reuse if num_proj is not None: @@ -1509,6 +1514,33 @@ class NASCell(rnn_cell_impl.RNNCell): def output_size(self): return self._output_size + def build(self, inputs_shape): + input_size = tensor_shape.dimension_value( + tensor_shape.TensorShape(inputs_shape).with_rank(2)[1]) + if input_size is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + num_proj = self._num_units if self._num_proj is None else self._num_proj + + # Variables for the NAS cell. `recurrent_kernel` is all matrices multiplying + # the hiddenstate and `kernel` is all matrices multiplying the inputs. + self.recurrent_kernel = self.add_variable( + "recurrent_kernel", [num_proj, self._NAS_BASE * self._num_units]) + self.kernel = self.add_variable( + "kernel", [input_size, self._NAS_BASE * self._num_units]) + + if self._use_bias: + self.bias = self.add_variable("bias", + shape=[self._NAS_BASE * self._num_units], + initializer=init_ops.zeros_initializer) + + # Projection layer if specified + if self._num_proj is not None: + self.projection_weights = self.add_variable( + "projection_weights", [self._num_units, self._num_proj]) + + self.built = True + def call(self, inputs, state): """Run one step of NAS Cell. 
@@ -1535,38 +1567,20 @@ class NASCell(rnn_cell_impl.RNNCell): tanh = math_ops.tanh relu = nn_ops.relu - num_proj = self._num_units if self._num_proj is None else self._num_proj - (c_prev, m_prev) = state - dtype = inputs.dtype - input_size = inputs.get_shape().with_rank(2).dims[1] - if input_size.value is None: - raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - # Variables for the NAS cell. W_m is all matrices multiplying the - # hiddenstate and W_inputs is all matrices multiplying the inputs. - concat_w_m = vs.get_variable("recurrent_kernel", - [num_proj, 8 * self._num_units], dtype) - concat_w_inputs = vs.get_variable( - "kernel", [input_size.value, 8 * self._num_units], dtype) - - m_matrix = math_ops.matmul(m_prev, concat_w_m) - inputs_matrix = math_ops.matmul(inputs, concat_w_inputs) - - if self._use_biases: - b = vs.get_variable( - "bias", - shape=[8 * self._num_units], - initializer=init_ops.zeros_initializer(), - dtype=dtype) - m_matrix = nn_ops.bias_add(m_matrix, b) + m_matrix = math_ops.matmul(m_prev, self.recurrent_kernel) + inputs_matrix = math_ops.matmul(inputs, self.kernel) + + if self._use_bias: + m_matrix = nn_ops.bias_add(m_matrix, self.bias) # The NAS cell branches into 8 different splits for both the hiddenstate # and the input m_matrix_splits = array_ops.split( - axis=1, num_or_size_splits=8, value=m_matrix) + axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix) inputs_matrix_splits = array_ops.split( - axis=1, num_or_size_splits=8, value=inputs_matrix) + axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix) # First layer layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) @@ -1598,9 +1612,7 @@ class NASCell(rnn_cell_impl.RNNCell): # Projection layer if specified if self._num_proj is not None: - concat_w_proj = vs.get_variable("projection_weights", - [self._num_units, self._num_proj], dtype) - new_m = math_ops.matmul(new_m, concat_w_proj) + new_m = math_ops.matmul(new_m, 
self.projection_weights) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m) return new_m, new_state @@ -2071,7 +2083,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): conv_ndims: Convolution dimensionality (1, 2 or 3). input_shape: Shape of the input as int tuple, excluding the batch size. output_channels: int, number of output channels of the conv LSTM. - kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3). + kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3). use_bias: (bool) Use bias in convolutions. skip_connection: If set to `True`, concatenate the input to the output of the conv LSTM. Default: `False`. @@ -2092,7 +2104,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): self._conv_ndims = conv_ndims self._input_shape = input_shape self._output_channels = output_channels - self._kernel_shape = kernel_shape + self._kernel_shape = list(kernel_shape) self._use_bias = use_bias self._forget_bias = forget_bias self._skip_connection = skip_connection @@ -2172,7 +2184,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0): Args: args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, batch x n, Tensors. - filter_size: int tuple of filter height and width. + filter_size: int tuple of filter shape (of size 1, 2 or 3). num_features: int, number of features. bias: Whether to use biases in the convolution layer. bias_start: starting value to initialize the bias; 0 by default. @@ -2744,10 +2756,12 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): name: (optional) String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. + **kwargs: Additional keyword arguments. 
""" - def __init__(self, num_units, activation=None, reuse=None, name=None): - super(SRUCell, self).__init__(_reuse=reuse, name=name) + def __init__(self, num_units, activation=None, reuse=None, name=None, + **kwargs): + super(SRUCell, self).__init__(_reuse=reuse, name=name, **kwargs) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -2777,7 +2791,7 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): self._bias = self.add_variable( rnn_cell_impl._BIAS_VARIABLE_NAME, # pylint: disable=protected-access shape=[2 * self._num_units], - initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) + initializer=init_ops.zeros_initializer) self._built = True @@ -3139,7 +3153,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): r"""Independently Gated Recurrent Unit cell. Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell, - yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and + yet with the \\(U_r\\), \\(U_z\\), and \\(U\\) matrices in equations 5, 6, and 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal matrices, i.e. a Hadamard product with a single vector: @@ -3150,12 +3164,10 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j + [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$ - where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU + where \\(\circ\\) denotes the Hadamard operator. This means that each IndyGRU node sees only its own state, as opposed to seeing all states in the same layer. - TODO(gonnet): Write a paper describing this and add a reference here. - Args: num_units: int, The number of units in the GRU cell. activation: Nonlinearity to use. Default: `tanh`. 
@@ -3240,7 +3252,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): self.built = True def call(self, inputs, state): - """Gated recurrent unit (GRU) with nunits cells.""" + """Recurrently independent Gated Recurrent Unit (GRU) with nunits cells.""" gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + ( gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u) @@ -3264,10 +3276,9 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): r"""Basic IndyLSTM recurrent network cell. Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to - BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\) - matrices in - https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate - replaced by diagonal matrices, i.e. a Hadamard product with a single vector: + BasicLSTMCell, yet with the \\(U_f\\), \\(U_i\\), \\(U_o\\) and \\(U_c\\) + matrices in the regular LSTM equations replaced by diagonal matrices, i.e. a + Hadamard product with a single vector: $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$ $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$ @@ -3275,8 +3286,8 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): $$c_t = f_t \circ c_{t-1} + i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$ - where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM - node sees only its own state \(h\) and \(c\), as opposed to seeing all + where \\(\circ\\) denotes the Hadamard operator. This means that each IndyLSTM + node sees only its own state \\(h\\) and \\(c\\), as opposed to seeing all states in the same layer. We add forget_bias (default: 1) to the biases of the forget gate in order to @@ -3284,11 +3295,6 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. - - For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` - that follows. 
- - TODO(gonnet): Write a paper describing this and add a reference here. """ def __init__(self, diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py index 3fc6bfbb4d03a39906d4441e48b2788423caa234..d8ab9eba7049e468b373a1641f92dc781aa22558 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py @@ -61,10 +61,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase): self._server = server def tearDown(self): - # TODO(ebrevdo): Figure out why this sometimes times out. - # self._service.ExitLoop() - # self._service_thread.join() - # self._server.stop() + self._server.stop(grace=None) super(RpcOpTest, self).tearDown() diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index 0d615923e04915a8429252317025ac8e79f9bb4e..d6148715be91c78e6e5a99fc0f3caa905b5c1a7d 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -176,7 +176,9 @@ class RpcOpTestBase(object): expected_message_values = np.where( status_code_values == errors.INVALID_ARGUMENT, I_WARNED_YOU.encode('ascii'), b'') - self.assertAllEqual(expected_message_values, status_message_values) + for msg, expected in zip(status_message_values, expected_message_values): + self.assertTrue(expected in msg, + '"%s" did not contain "%s"' % (msg, expected)) def testVecHostPortRpc(self): with self.cached_session() as sess: diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 269443b2c6508bb618d30f64487b1a6a84e8646f..f0242a3b40fd566ec0f477d462426d5f550d1620 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -84,35 +84,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = 
[ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:saver", - "//tensorflow/python:util", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/keras:engine", - "//tensorflow/python/saved_model", - ], -) - -py_test( - name = "keras_saved_model_test", - size = "medium", - srcs = ["python/saved_model/keras_saved_model_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_oss", # TODO(b/119349471): Re-enable - "no_windows", - ], - deps = [ - ":keras_saved_model", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index ffba514bb96f5ce8d963cb0a0482738eafe88355..a61e9579b84a60d74b73e45a6100a2c772d9cff8 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -18,348 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import six +from tensorflow.python.keras import saving -from tensorflow.python.client import session -from tensorflow.python.estimator import keras as estimator_keras_util -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export as export_helpers -from tensorflow.python.framework import ops -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import models as models_lib -from tensorflow.python.keras import optimizers -from tensorflow.python.keras.engine import sequential -from 
tensorflow.python.keras.metrics import Metric -from tensorflow.python.keras.models import model_from_json -from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.saved_model import builder as saved_model_builder -from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import utils_impl as saved_model_utils -from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import util as checkpointable_utils -from tensorflow.python.util import compat - -def save_keras_model( - model, saved_model_path, custom_objects=None, as_text=None): - """Save a `tf.keras.Model` into Tensorflow SavedModel format. - - `save_model` generates new files/folders under the `saved_model_path` folder: - 1) an asset folder containing the json string of the model's - configuration (topology). - 2) a checkpoint containing the model weights. - 3) a saved_model.pb file containing the model's MetaGraphs. The prediction - graph is always exported. The evaluaton and training graphs are exported - if the following conditions are met: - - Evaluation: model loss is defined. - - Training: model is compiled with an optimizer defined under `tf.train`. - This is because `tf.keras.optimizers.Optimizer` instances cannot be - saved to checkpoints. - - Model Requirements: - - Model must be a sequential model or functional model. Subclassed models can - not be saved via this function, unless you provide an implementation for - get_config() and from_config(). - - All variables must be saveable by the model. In general, this condition is - met through the use of layers defined in the keras library. However, - there is currently a bug with variables created in Lambda layer functions - not being saved correctly (see - https://github.com/keras-team/keras/issues/9740). 
- - Note that each mode is exported in separate graphs, so different modes do not - share variables. To use the train graph with evaluation or prediction graphs, - create a new checkpoint if variable values have been updated. - - Example: - - ```python - import tensorflow as tf - - # Create a tf.keras model. - model = tf.keras.Sequential() - model.add(tf.keras.layers.Dense(1, input_shape=[10])) - model.summary() - - # Save the tf.keras model in the SavedModel format. - saved_to_path = tf.contrib.saved_model.save_keras_model( - model, '/tmp/my_simple_tf_keras_saved_model') - - # Load the saved keras model back. - model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path) - model_prime.summary() - ``` - - Args: - model: A `tf.keras.Model` to be saved. - saved_model_path: a string specifying the path to the SavedModel directory. - The SavedModel will be saved to a timestamped folder created within this - directory. - custom_objects: Optional dictionary mapping string names to custom classes - or functions (e.g. custom loss functions). - as_text: whether to write the `SavedModel` proto in text format. - - Returns: - String path to the SavedModel folder, a subdirectory of `saved_model_path`. - - Raises: - NotImplementedError: If the model is a subclassed model. - ValueError: If a Sequential model does not have input shapes defined by the - user, and is not built. - """ - if not model._is_graph_network: - if isinstance(model, sequential.Sequential): - # If input shape is not directly set in the model, the exported model - # will assume that the inputs have the same shape as the shape the model - # was built model with. 
- if not model.built: - raise ValueError( - 'Sequential model must be built before it can be exported.') - else: - raise NotImplementedError( - 'Exporting subclassed models is not yet supported.') - - export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) - temp_export_dir = export_helpers.get_temp_export_dir(export_dir) - - builder = saved_model_builder._SavedModelBuilder(temp_export_dir) - - # Manually save variables to export them in an object-based checkpoint. This - # skips the `builder.add_meta_graph_and_variables()` step, which saves a - # named-based checkpoint. - # TODO(b/113134168): Add fn to Builder to save with object-based saver. - # TODO(b/113178242): This should only export the model json structure. Only - # one save is needed once the weights can be copied from the model to clone. - checkpoint_path = _export_model_json_and_variables(model, temp_export_dir) - - # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that - # Keras models and `Estimator`s are exported with the same format. - # Every time a mode is exported, the code checks to see if new variables have - # been created (e.g. optimizer slot variables). If that is the case, the - # checkpoint is re-saved to include the new variables. - export_args = {'builder': builder, - 'model': model, - 'custom_objects': custom_objects, - 'checkpoint_path': checkpoint_path} - - has_saved_vars = False - if model.optimizer: - if isinstance(model.optimizer, optimizers.TFOptimizer): - _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args) - has_saved_vars = True - _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars, **export_args) - else: - logging.warning( - 'Model was compiled with an optimizer, but the optimizer is not from ' - '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving ' - 'graph was exported. 
The train and evaluate graphs were not added to ' - 'the SavedModel.') - _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, **export_args) - - builder.save(as_text) - - gfile.Rename(temp_export_dir, export_dir) - return export_dir - - -def _export_model_json_and_variables(model, saved_model_path): - """Save model variables and json structure into SavedModel subdirectories.""" - # Save model configuration as a json string under assets folder. - model_json = model.to_json() - model_json_filepath = os.path.join( - saved_model_utils.get_or_create_assets_dir(saved_model_path), - compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) - file_io.write_string_to_file(model_json_filepath, model_json) - - # Save model weights in checkpoint format under variables folder. - saved_model_utils.get_or_create_variables_dir(saved_model_path) - checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) - model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) - return checkpoint_prefix - - -def _get_var_list(model): - """Return list of all checkpointed saveable objects in the model.""" - return checkpointable_utils.named_saveables(model) - - -def _export_mode( - mode, has_saved_vars, builder, model, custom_objects, checkpoint_path): - """Export a model, and optionally save new vars from the clone model. - - Args: - mode: A `tf.estimator.ModeKeys` string. - has_saved_vars: A `boolean` indicating whether the SavedModel has already - exported variables. - builder: A `SavedModelBuilder` object. - model: A `tf.keras.Model` object. - custom_objects: A dictionary mapping string names to custom classes - or functions. - checkpoint_path: String path to checkpoint. - - Raises: - ValueError: If the train/eval mode is being exported, but the model does - not have an optimizer. - """ - compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT) - if compile_clone and not model.optimizer: - raise ValueError( - 'Model does not have an optimizer. 
Cannot export mode %s' % mode) - - model_graph = ops.get_default_graph() - with ops.Graph().as_default() as g: - - K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN) - - # Clone the model into blank graph. This will create placeholders for inputs - # and targets. - clone = models_lib.clone_and_build_model( - model, custom_objects=custom_objects, compile_clone=compile_clone) - - # Make sure that iterations variable is added to the global step collection, - # to ensure that, when the SavedModel graph is loaded, the iterations - # variable is returned by `tf.train.get_global_step()`. This is required for - # compatibility with the SavedModelEstimator. - if compile_clone: - g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations) - - # Extract update and train ops from train/test/predict functions. - train_op = None - if mode == model_fn_lib.ModeKeys.TRAIN: - clone._make_train_function() - train_op = clone.train_function.updates_op - elif mode == model_fn_lib.ModeKeys.EVAL: - clone._make_test_function() - else: - clone._make_predict_function() - g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates) - - clone_var_list = checkpointable_utils.named_saveables(clone) - - with session.Session().as_default(): - if has_saved_vars: - # Confirm all variables in the clone have an entry in the checkpoint. - status = clone.load_weights(checkpoint_path) - status.assert_existing_objects_matched() - else: - # Confirm that variables between the clone and model match up exactly, - # not counting optimizer objects. Optimizer objects are ignored because - # if the model has not trained, the slot variables will not have been - # created yet. - # TODO(b/113179535): Replace with checkpointable equivalence. - _assert_same_non_optimizer_objects(model, model_graph, clone, g) - - # TODO(b/113178242): Use value transfer for checkpointable objects. - clone.load_weights(checkpoint_path) - - # Add graph and variables to SavedModel. 
- # TODO(b/113134168): Switch to add_meta_graph_and_variables. - clone.save_weights(checkpoint_path, save_format='tf', overwrite=True) - builder._has_saved_variables = True - - # Add graph to the SavedModel builder. - builder.add_meta_graph( - model_fn_lib.EXPORT_TAG_MAP[mode], - signature_def_map=_create_signature_def_map(clone, mode), - saver=saver_lib.Saver(clone_var_list), - init_op=variables.local_variables_initializer(), - train_op=train_op) - return None - - -def _create_signature_def_map(model, mode): - """Create a SignatureDef map from a Keras model.""" - inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)} - if model.optimizer: - targets_dict = {x.name.split(':')[0]: x - for x in model.targets if x is not None} - inputs_dict.update(targets_dict) - outputs_dict = {name: x - for name, x in zip(model.output_names, model.outputs)} - metrics = estimator_keras_util._convert_keras_metrics_to_estimator(model) - - # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables - # are by default not added to any collections. We are doing this here, so - # that metric variables get initialized. - local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) - vars_to_add = set() - if metrics is not None: - for key, value in six.iteritems(metrics): - if isinstance(value, Metric): - vars_to_add.update(value.variables) - # Convert Metric instances to (value_tensor, update_op) tuple. - metrics[key] = (value.result(), value.updates[0]) - # Remove variables that are in the local variables collection already. 
- vars_to_add = vars_to_add.difference(local_vars) - for v in vars_to_add: - ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v) - - export_outputs = model_fn_lib.export_outputs_for_mode( - mode, - predictions=outputs_dict, - loss=model.total_loss if model.optimizer else None, - metrics=metrics) - return export_helpers.build_all_signature_defs( - inputs_dict, - export_outputs=export_outputs, - serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) - - -def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument - """Assert model and clone contain the same checkpointable objects.""" - - # TODO(fchollet, kathywu): make sure this works in eager mode. - return True - - -def load_keras_model(saved_model_path): - """Load a keras.Model from SavedModel. - - load_model reinstantiates model state by: - 1) loading model topology from json (this will eventually come - from metagraph). - 2) loading model weights from checkpoint. - - Example: - - ```python - import tensorflow as tf - - # Create a tf.keras model. - model = tf.keras.Sequential() - model.add(tf.keras.layers.Dense(1, input_shape=[10])) - model.summary() - - # Save the tf.keras model in the SavedModel format. - saved_to_path = tf.contrib.saved_model.save_keras_model( - model, '/tmp/my_simple_tf_keras_saved_model') - - # Load the saved keras model back. - model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path) - model_prime.summary() - ``` - - Args: - saved_model_path: a string specifying the path to an existing SavedModel. - - Returns: - a keras.Model instance. 
- """ - # restore model topology from json string - model_json_filepath = os.path.join( - compat.as_bytes(saved_model_path), - compat.as_bytes(constants.ASSETS_DIRECTORY), - compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) - model_json = file_io.read_file_to_string(model_json_filepath) - model = model_from_json(model_json) - - # restore model weights - checkpoint_prefix = os.path.join( - compat.as_text(saved_model_path), - compat.as_text(constants.VARIABLES_DIRECTORY), - compat.as_text(constants.VARIABLES_FILENAME)) - model.load_weights(checkpoint_prefix) - return model +# TODO(kathywu): Remove all contrib callers, switch to tf.keras. +save_keras_model = saving.export_saved_model +load_keras_model = saving.load_from_saved_model diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 18b56cd21942e28cb0dc3210df0bb04d55c1e16f..8e2ce82294287dda07d2067c5b9f012f510dbd08 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -33,7 +33,6 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":beam_search_ops", - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/contrib/util:util_py", @@ -59,7 +58,6 @@ tf_custom_op_py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "@six_archive//:six", ], @@ -141,6 +139,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "basic_decoder_v2_test", + size = "medium", + srcs = ["python/kernel_tests/basic_decoder_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + 
"//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "beam_search_ops_test", size = "medium", @@ -175,6 +194,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "decoder_v2_test", + size = "medium", + srcs = ["python/kernel_tests/decoder_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "beam_search_decoder_test", size = "medium", @@ -215,3 +255,19 @@ cuda_py_test( "//tensorflow/python:variables", ], ) + +cuda_py_test( + name = "attention_wrapper_v2_test", + size = "medium", + srcs = ["python/kernel_tests/attention_wrapper_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], + shard_count = 4, +) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 922f21b98b35dfff19c8c605a25e89c5d2da8d98..1a5692f7b5be5e87b78dac9d1ae51f280ca089f8 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -35,6 +35,7 @@ from 
tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope as vs @@ -357,7 +358,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00597103), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.6)) + shape=(5, 3), dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -386,7 +387,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333)) + shape=(5, 3), dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -453,7 +454,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333333333)) + shape=(5, 3), dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -695,7 +696,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0025896581), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.6)) + shape=(5, 3), dtype=dtype('int32'), mean=1.73333333)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -706,12 +707,12 @@ class AttentionWrapperTest(test.TestCase): shape=(5, 6), dtype=dtype('float32'), mean=-0.00069823361), time=3, alignments=ResultSummary( - 
shape=(5, 8), dtype=dtype('float32'), mean=0.028698336), + shape=(5, 8), dtype=dtype('float32'), mean=0.029914695), attention_state=ResultSummary( - shape=(5, 8), dtype=dtype('float32'), mean=0.028698336), + shape=(5, 8), dtype=dtype('float32'), mean=0.029914695), alignment_history=()) expected_final_alignment_history = ResultSummary( - shape=(3, 5, 8), dtype=dtype('float32'), mean=0.04865776002407074) + shape=(3, 5, 8), dtype=dtype('float32'), mean=0.0465225502849) self._testWithAttention( create_attention_mechanism, @@ -920,9 +921,9 @@ class AttentionWrapperTest(test.TestCase): expected_final_output = BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11723966), + shape=(5, 3, 20), dtype=dtype('float32'), mean=0.115853324533), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=7.266666666666667)) + shape=(5, 3), dtype=dtype('int32'), mean=8.6)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -930,7 +931,7 @@ class AttentionWrapperTest(test.TestCase): h=ResultSummary( shape=(5, 9), dtype=dtype('float32'), mean=-0.0018327223)), attention=ResultSummary( - shape=(5, 20), dtype=dtype('float32'), mean=0.11601614207), + shape=(5, 20), dtype=dtype('float32'), mean=0.11462739855), time=3, alignments=(ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.125), @@ -992,5 +993,67 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=expected_final_alignment_history, name='testMultiAttention') + def testCustomizedAttention(self): + batch_size = 2 + max_time = 3 + num_units = 2 + memory = constant_op.constant([[[1., 1.], [2., 2.], [3., 3.]], + [[4., 4.], [5., 5.], [6., 6.]]]) + memory_sequence_length = constant_op.constant([3, 2]) + attention_mechanism = wrapper.BahdanauAttention(num_units, memory, + memory_sequence_length) + + # Sets all returned values to be all ones. 
+ def _customized_attention(unused_attention_mechanism, unused_cell_output, + unused_attention_state, unused_attention_layer): + """Customized attention. + + Returns: + attention: `Tensor` of shape [batch_size, num_units], attention output. + alignments: `Tensor` of shape [batch_size, max_time], sigma value for + each input memory (prob. function of input keys). + next_attention_state: A `Tensor` representing the next state for the + attention. + """ + attention = array_ops.ones([batch_size, num_units]) + alignments = array_ops.ones([batch_size, max_time]) + next_attention_state = alignments + return attention, alignments, next_attention_state + + attention_cell = wrapper.AttentionWrapper( + rnn_cell.LSTMCell(2), + attention_mechanism, + attention_layer_size=None, # don't use attention layer. + output_attention=False, + alignment_history=(), + attention_fn=_customized_attention, + name='attention') + self.assertEqual(num_units, attention_cell.output_size) + + initial_state = attention_cell.zero_state( + batch_size=2, dtype=dtypes.float32) + source_input_emb = array_ops.ones([2, 3, 2]) + source_input_length = constant_op.constant([3, 2]) + + # 'state' is a tuple of + # (cell_state, h, attention, alignments, alignment_history, attention_state) + output, state = rnn.dynamic_rnn( + attention_cell, + inputs=source_input_emb, + sequence_length=source_input_length, + initial_state=initial_state, + dtype=dtypes.float32) + + with self.session() as sess: + sess.run(variables.global_variables_initializer()) + output_value, state_value = sess.run([output, state], feed_dict={}) + self.assertAllEqual(np.array([2, 3, 2]), output_value.shape) + self.assertAllClose(np.array([[1., 1.], [1., 1.]]), state_value.attention) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.alignments) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.attention_state) + + if __name__ == '__main__': test.main() diff --git 
a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5ee01f66f165bd2ac22cae10807f24f6b97f0c64 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py @@ -0,0 +1,745 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.seq2seq.python.ops.attention_wrapper.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py +from tensorflow.python import keras +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.keras import initializers +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + 
+@test_util.run_all_in_graph_and_eager_modes +class AttentionMechanismTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super(AttentionMechanismTest, self).setUp() + self.batch = 10 + self.timestep = 5 + self.memory_size = 6 + self.units = 8 + + self.memory = np.random.randn(self.batch, self.timestep, + self.memory_size).astype(np.float32) + self.query = np.random.randn(self.batch, self.units).astype(np.float32) + self.state = np.random.randn(self.batch, self.timestep).astype(np.float32) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_attention_shape_inference(self, attention_cls): + attention = attention_cls(self.units, self.memory) + attention_score = attention([self.query, self.state]) + self.assertLen(attention_score, 2) + self.assertEqual(attention_score[0].shape, (self.batch, self.timestep)) + self.assertEqual(attention_score[1].shape, (self.batch, self.timestep)) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_get_config(self, attention_cls): + attention = attention_cls(self.units, self.memory) + config = attention.get_config() + + attention_from_config = attention_cls.from_config(config) + config_from_clone = attention_from_config.get_config() + + self.assertDictEqual(config, config_from_clone) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_layer_output(self, attention_cls): + attention = attention_cls(self.units, 
self.memory) + score = attention([self.query, self.state]) + self.evaluate(variables.variables_initializer(attention.variables)) + + score_val = self.evaluate(score) + self.assertLen(score_val, 2) + self.assertEqual(score_val[0].shape, (self.batch, self.timestep)) + self.assertEqual(score_val[1].shape, (self.batch, self.timestep)) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_passing_memory_from_call(self, attention_cls): + attention = attention_cls(self.units, self.memory) + weights_before_query = attention.get_weights() + ref_score = attention([self.query, self.state]) + + self.evaluate(variables.global_variables_initializer()) + ref_score_val = self.evaluate(ref_score) + + all_weights = attention.get_weights() + config = attention.get_config() + # Simulate the twice invocation of calls here. 
+ attention_from_config = attention_cls.from_config(config) + attention_from_config.build(self.memory.shape) + attention_from_config.set_weights(weights_before_query) + attention_from_config(self.memory, setup_memory=True) + attention_from_config.build([self.query.shape, self.state.shape]) + attention_from_config.set_weights(all_weights) + score = attention_from_config([self.query, self.state]) + + score_val = self.evaluate(score) + self.assertAllClose(ref_score_val, score_val) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_save_load_layer(self, attention_cls): + vocab = 20 + embedding_dim = 6 + inputs = keras.layers.Input(shape=[self.timestep]) + encoder_input = keras.layers.Embedding( + vocab, embedding_dim, mask_zero=True)( + inputs) + encoder_output = keras.layers.UnifiedLSTM( + self.memory_size, return_sequences=True)( + encoder_input) + + attention = attention_cls(self.units, encoder_output) + query = keras.layers.Input(shape=[self.units]) + state = keras.layers.Input(shape=[self.timestep]) + + score = attention([query, state]) + + x = np.random.randint(vocab, size=(self.batch, self.timestep)) + x_test = np.random.randint(vocab, size=(self.batch, self.timestep)) + y = np.random.randn(self.batch, self.timestep) + model = keras.models.Model([inputs, query, state], score) + model.compile("rmsprop", "mse") + model.fit([x, self.query, self.state], (y, y)) + y_ref = model.predict_on_batch([x_test, self.query, self.state]) + + config = model.get_config() + weights = model.get_weights() + loaded_model = keras.models.Model.from_config( + config, custom_objects={attention_cls.__name__: attention_cls}) + loaded_model.set_weights(weights) + + y = loaded_model.predict_on_batch([x_test, self.query, self.state]) + + self.assertAllClose(y_ref, y) + + # TODO(scottzhu): 
Add tests for model.compile(run_eagerly=True) + + +class ResultSummary( + collections.namedtuple("ResultSummary", ("shape", "dtype", "mean"))): + pass + + +def get_result_summary(x): + if isinstance(x, np.ndarray): + return ResultSummary(x.shape, x.dtype, x.mean()) + return x + + +@test_util.run_all_in_graph_and_eager_modes +class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): + + def assertAllCloseOrEqual(self, x, y, **kwargs): + if isinstance(x, np.ndarray) or isinstance(x, float): + return super(AttentionWrapperV2Test, self).assertAllClose( + x, y, atol=1e-3, **kwargs) + else: + self.assertAllEqual(x, y, **kwargs) + + def setUp(self): + super(AttentionWrapperV2Test, self).setUp() + self.batch = 64 + self.units = 128 + self.encoder_timestep = 10 + self.encoder_dim = 256 + self.decoder_timestep = 12 + self.encoder_outputs = np.random.randn(self.batch, self.encoder_timestep, + self.encoder_dim) + self.encoder_sequence_length = np.random.randint( + self.encoder_timestep, size=(self.batch,)).astype(np.int32) + self.decoder_inputs = np.random.randn(self.batch, self.decoder_timestep, + self.units) + self.decoder_sequence_length = np.random.randint( + self.decoder_timestep, size=(self.batch,)).astype(np.int32) + + def _testWithAttention(self, + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=3, + alignment_history=False, + expected_final_alignment_history=None, + attention_layer_size=6, + attention_layer=None, + create_query_layer=False, + create_memory_layer=True, + create_attention_kwargs=None): + attention_layer_sizes = ([attention_layer_size] + if attention_layer_size is not None else None) + attention_layers = ([attention_layer] + if attention_layer is not None else None) + self._testWithMaybeMultiAttention( + is_multi=False, + create_attention_mechanisms=[create_attention_mechanism], + expected_final_output=expected_final_output, + expected_final_state=expected_final_state, + 
attention_mechanism_depths=[attention_mechanism_depth], + alignment_history=alignment_history, + expected_final_alignment_history=expected_final_alignment_history, + attention_layer_sizes=attention_layer_sizes, + attention_layers=attention_layers, + create_query_layer=create_query_layer, + create_memory_layer=create_memory_layer, + create_attention_kwargs=create_attention_kwargs) + + def _testWithMaybeMultiAttention(self, + is_multi, + create_attention_mechanisms, + expected_final_output, + expected_final_state, + attention_mechanism_depths, + alignment_history=False, + expected_final_alignment_history=None, + attention_layer_sizes=None, + attention_layers=None, + create_query_layer=False, + create_memory_layer=True, + create_attention_kwargs=None): + # Allow is_multi to be True with a single mechanism to enable test for + # passing in a single mechanism in a list. + assert len(create_attention_mechanisms) == 1 or is_multi + encoder_sequence_length = [3, 2, 3, 1, 1] + decoder_sequence_length = [2, 0, 1, 2, 3] + batch_size = 5 + encoder_max_time = 8 + decoder_max_time = 4 + input_depth = 7 + encoder_output_depth = 10 + cell_depth = 9 + create_attention_kwargs = create_attention_kwargs or {} + + if attention_layer_sizes is not None: + # Compute sum of attention_layer_sizes. Use encoder_output_depth if None. + attention_depth = sum(attention_layer_size or encoder_output_depth + for attention_layer_size in attention_layer_sizes) + elif attention_layers is not None: + # Compute sum of attention_layers output depth. 
+ attention_depth = sum( + attention_layer.compute_output_shape( + [batch_size, cell_depth + encoder_output_depth]).dims[-1].value + for attention_layer in attention_layers) + else: + attention_depth = encoder_output_depth * len(create_attention_mechanisms) + + decoder_inputs = np.random.randn(batch_size, decoder_max_time, + input_depth).astype(np.float32) + encoder_outputs = np.random.randn(batch_size, encoder_max_time, + encoder_output_depth).astype(np.float32) + + attention_mechanisms = [] + for creator, depth in zip(create_attention_mechanisms, + attention_mechanism_depths): + # Create a memory layer with deterministic initializer to avoid randomness + # in the test between graph and eager. + if create_query_layer: + create_attention_kwargs["query_layer"] = keras.layers.Dense( + depth, kernel_initializer="ones", use_bias=False) + if create_memory_layer: + create_attention_kwargs["memory_layer"] = keras.layers.Dense( + depth, kernel_initializer="ones", use_bias=False) + + attention_mechanisms.append( + creator( + units=depth, + memory=encoder_outputs, + memory_sequence_length=encoder_sequence_length, + **create_attention_kwargs)) + + with self.cached_session(use_gpu=True): + attention_layer_size = attention_layer_sizes + attention_layer = attention_layers + if not is_multi: + if attention_layer_size is not None: + attention_layer_size = attention_layer_size[0] + if attention_layer is not None: + attention_layer = attention_layer[0] + cell = rnn_cell.LSTMCell(cell_depth, initializer="ones") + cell = wrapper.AttentionWrapper( + cell, + attention_mechanisms if is_multi else attention_mechanisms[0], + attention_layer_size=attention_layer_size, + alignment_history=alignment_history, + attention_layer=attention_layer) + # Set the attention_layer within AttentionWrapper to have deterministic + # kernel initializer, for testing purpose. 
+ if cell._attention_layers is not None: + for layer in cell._attention_layers: + if getattr(layer, "kernel_initializer") is None: + layer.kernel_initializer = initializers.ones() + + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + final_outputs, final_state, _ = my_decoder( + decoder_inputs, + initial_state=initial_state, + sequence_length=decoder_sequence_length) + + self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) + self.assertIsInstance(final_state, wrapper.AttentionWrapperState) + self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple) + + expected_time = ( + expected_final_state.time if context.executing_eagerly() else None) + self.assertEqual((batch_size, expected_time, attention_depth), + tuple(final_outputs.rnn_output.get_shape().as_list())) + self.assertEqual((batch_size, expected_time), + tuple(final_outputs.sample_id.get_shape().as_list())) + + self.assertEqual((batch_size, attention_depth), + tuple(final_state.attention.get_shape().as_list())) + self.assertEqual((batch_size, cell_depth), + tuple(final_state.cell_state.c.get_shape().as_list())) + self.assertEqual((batch_size, cell_depth), + tuple(final_state.cell_state.h.get_shape().as_list())) + + if alignment_history: + if is_multi: + state_alignment_history = [] + for history_array in final_state.alignment_history: + history = history_array.stack() + self.assertEqual((expected_time, batch_size, encoder_max_time), + tuple(history.get_shape().as_list())) + state_alignment_history.append(history) + state_alignment_history = tuple(state_alignment_history) + else: + state_alignment_history = final_state.alignment_history.stack() + self.assertEqual((expected_time, batch_size, encoder_max_time), + tuple(state_alignment_history.get_shape().as_list())) + nest.assert_same_structure(cell.state_size, + cell.zero_state(batch_size, 
dtypes.float32)) + # Remove the history from final_state for purposes of the + # remainder of the tests. + final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access + else: + state_alignment_history = () + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "final_outputs": final_outputs, + "final_state": final_state, + "state_alignment_history": state_alignment_history, + }) + + final_output_info = nest.map_structure(get_result_summary, + eval_result["final_outputs"]) + final_state_info = nest.map_structure(get_result_summary, + eval_result["final_state"]) + print("final_output_info: ", final_output_info) + print("final_state_info: ", final_state_info) + + nest.map_structure(self.assertAllCloseOrEqual, expected_final_output, + final_output_info) + nest.map_structure(self.assertAllCloseOrEqual, expected_final_state, + final_state_info) + if alignment_history: # by default, the wrapper emits attention as output + final_alignment_history_info = nest.map_structure( + get_result_summary, eval_result["state_alignment_history"]) + print("final_alignment_history_info: ", final_alignment_history_info) + nest.map_structure( + self.assertAllCloseOrEqual, + # outputs are batch major but the stacked TensorArray is time major + expected_final_alignment_history, + final_alignment_history_info) + + @parameterized.parameters([np.float16, np.float32, np.float64]) + def _testBahdanauNormalizedDType(self, dtype): + encoder_outputs = self.encoder_outputs.astype(dtype) + decoder_inputs = self.decoder_inputs.astype(dtype) + attention_mechanism = wrapper.BahdanauAttentionV2( + units=self.units, + memory=encoder_outputs, + memory_sequence_length=self.encoder_sequence_length, + normalize=True, + dtype=dtype) + cell = rnn_cell.LSTMCell(self.units) + cell = wrapper.AttentionWrapper(cell, attention_mechanism) + + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, 
sampler=sampler) + + final_outputs, final_state, _ = my_decoder( + decoder_inputs, + initial_state=cell.zero_state(dtype=dtype, batch_size=self.batch), + sequence_length=self.decoder_sequence_length) + self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) + self.assertEqual(final_outputs.rnn_output.dtype, dtype) + self.assertIsInstance(final_state, wrapper.AttentionWrapperState) + self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple) + + @parameterized.parameters([np.float16, np.float32, np.float64]) + def testLuongScaledDType(self, dtype): + # Test case for GitHub issue 18099 + encoder_outputs = self.encoder_outputs.astype(dtype) + decoder_inputs = self.decoder_inputs.astype(dtype) + attention_mechanism = wrapper.LuongAttentionV2( + units=self.units, + memory=encoder_outputs, + memory_sequence_length=self.encoder_sequence_length, + scale=True, + dtype=dtype, + ) + cell = rnn_cell.LSTMCell(self.units) + cell = wrapper.AttentionWrapper(cell, attention_mechanism) + + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + + final_outputs, final_state, _ = my_decoder( + decoder_inputs, + initial_state=cell.zero_state(dtype=dtype, batch_size=self.batch), + sequence_length=self.decoder_sequence_length) + self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) + self.assertEqual(final_outputs.rnn_output.dtype, dtype) + self.assertIsInstance(final_state, wrapper.AttentionWrapperState) + self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple) + + def testBahdanauNotNormalized(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324), + sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0)) + expected_final_state = 
wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype(np.float32), mean=0.125) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + create_query_layer=True, + expected_final_alignment_history=expected_final_alignment_history, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauNormalized(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.9548259), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + 
create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testLuongNotNormalized(self): + create_attention_mechanism = wrapper.LuongAttentionV2 + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=2.6605489), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=4.084631), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9) + + def testLuongScaled(self): + create_attention_mechanism = wrapper.LuongAttentionV2 + create_attention_kwargs = {"scale": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=2.6605489), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + 
alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + create_attention_kwargs=create_attention_kwargs) + + def testNotUseAttentionLayer(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 10), dtype=np.dtype("float32"), mean=0.072406612), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742)), + attention=ResultSummary( + shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_layer_size=None, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauMonotonicNotNormalized(self): + create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=5.9850435), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492), + h=ResultSummary( + shape=(5, 9), 
dtype=np.dtype("float32"), mean=0.76052248)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.10989678), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.10989678), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.117412611) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauMonotonicNormalized(self): + create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones", + "normalize": True} + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=4.5706983), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.12258384), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.12258384), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12258384) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + 
expected_final_alignment_history=expected_final_alignment_history, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testLuongMonotonicNotNormalized(self): + create_attention_mechanism = wrapper.LuongMonotonicAttentionV2 + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.159497), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.11899644) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history) + + def testLuongMonotonicScaled(self): + create_attention_mechanism = wrapper.LuongMonotonicAttentionV2 + create_attention_kwargs = {"scale": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.159497), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384), + h=ResultSummary( + shape=(5, 9), 
dtype=np.dtype("float32"), mean=0.50331038)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.11899644) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + create_attention_kwargs=create_attention_kwargs) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index b7f9f3fb090356a1c8d2bfb5044712ff93e267ce..abcf71c61b6e6df9462bf06323b8b11d5cc0d9a8 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -34,8 +34,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top @@ -517,7 +515,7 @@ class BasicDecoderTest(test.TestCase): vocabulary_size) # The sample function samples categorically from the logits. - sample_fn = lambda x: categorical.Categorical(logits=x).sample() + sample_fn = lambda x: helper_py.categorical_sample(logits=x) # The next inputs are a one-hot encoding of the sampled labels. 
next_inputs_fn = ( lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)) @@ -599,7 +597,7 @@ class BasicDecoderTest(test.TestCase): # The sample function samples independent bernoullis from the logits. sample_fn = ( - lambda x: bernoulli.Bernoulli(logits=x, dtype=dtypes.bool).sample()) + lambda x: helper_py.bernoulli_sample(logits=x, dtype=dtypes.bool)) # The next inputs are a one-hot encoding of the sampled labels. next_inputs_fn = math_ops.to_float end_fn = lambda sample_ids: sample_ids[:, end_token] diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2341ebb77ab6ecad1e979bc8bed0080128a804da --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py @@ -0,0 +1,670 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.layers import core as layers_core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +@keras_parameterized.run_all_keras_modes +class BasicDecoderTest(keras_parameterized.TestCase): + """Unit test for basic_decoder.BasicDecoderV2.""" + + @parameterized.named_parameters( + ("use_output_layer", True), + ("without_output_layer", False)) + def testStepWithTrainingHelperOutputLayer(self, use_output_layer): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + output_layer_depth = 3 + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + cell = rnn_cell.LSTMCell(cell_depth) + sampler = sampler_py.TrainingSampler(time_major=False) + if use_output_layer: + output_layer = layers_core.Dense(output_layer_depth, use_bias=False) + expected_output_depth = output_layer_depth + else: + output_layer = None + expected_output_depth = cell_depth + initial_state = cell.zero_state(dtype=dtypes.float32, + batch_size=batch_size) + my_decoder = 
basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler, + output_layer=output_layer) + + (first_finished, + first_inputs, + first_state) = my_decoder.initialize(input_t, + initial_state=initial_state, + sequence_length=sequence_length) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(expected_output_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, expected_output_depth), + step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + if use_output_layer: + # The output layer was accessed + self.assertEqual(len(output_layer.variables), 1) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, 
True], + eval_result["step_finished"]) + self.assertEqual(output_dtype.sample_id, + eval_result["step_outputs"].sample_id.dtype) + self.assertAllEqual( + np.argmax(eval_result["step_outputs"].rnn_output, -1), + eval_result["step_outputs"].sample_id) + + def DISABLED_testStepWithGreedyEmbeddingHelper(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + start_tokens = np.random.randint(0, vocabulary_size, size=batch_size) + end_token = 1 + + with self.cached_session(use_gpu=True): + embeddings = np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + embeddings_t = constant_op.constant(embeddings) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.GreedyEmbeddingSampler() + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + embeddings_t, + start_tokens=start_tokens, + end_token=end_token, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), 
first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + expected_sample_ids = np.argmax( + eval_result["step_outputs"].rnn_output, -1) + expected_step_finished = (expected_sample_ids == end_token) + expected_step_next_inputs = embeddings[expected_sample_ids] + self.assertAllEqual([False, False, False, False, False], + eval_result["first_finished"]) + self.assertAllEqual(expected_step_finished, eval_result["step_finished"]) + self.assertEqual(output_dtype.sample_id, + eval_result["step_outputs"].sample_id.dtype) + self.assertAllEqual(expected_sample_ids, + eval_result["step_outputs"].sample_id) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + def testStepWithSampleEmbeddingHelper(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + np.random.seed(0) + start_tokens = np.random.randint(0, vocabulary_size, size=batch_size) + end_token = 1 + + with self.cached_session(use_gpu=True): + embeddings = np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + embeddings_t = constant_op.constant(embeddings) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.SampleEmbeddingSampler(seed=0) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + (first_finished, + first_inputs, + 
first_state) = my_decoder.initialize(embeddings_t, + start_tokens=start_tokens, + end_token=end_token, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = (sample_ids == end_token) + expected_step_next_inputs = embeddings[sample_ids] + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + 
def testStepWithScheduledEmbeddingTrainingHelper(self): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + vocabulary_size = 10 + + with self.cached_session(use_gpu=True): + inputs = np.random.randn( + batch_size, max_time, input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + embeddings = np.random.randn( + vocabulary_size, input_depth).astype(np.float32) + half = constant_op.constant(0.5) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.ScheduledEmbeddingTrainingSampler( + sampling_probability=half, + time_major=False) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + input_t, sequence_length=sequence_length, embedding=embeddings, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(vocabulary_size, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, vocabulary_size), + step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + first_state[0].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + first_state[1].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + 
step_state[0].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + step_state[1].get_shape()) + self.assertEqual((batch_size, input_depth), + step_next_inputs.get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + eval_result["step_finished"]) + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + batch_where_not_sampling = np.where(sample_ids == -1) + batch_where_sampling = np.where(sample_ids > -1) + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_sampling], + embeddings[sample_ids[batch_where_sampling]]) + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_not_sampling], + np.squeeze(inputs[batch_where_not_sampling, 1], axis=0)) + + def _testStepWithScheduledOutputTrainingHelper( + self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = input_depth + if use_auxiliary_inputs: + auxiliary_input_depth = 4 + auxiliary_inputs = np.random.randn( + batch_size, max_time, auxiliary_input_depth).astype(np.float32) + else: + auxiliary_inputs = None + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + cell = rnn_cell.LSTMCell(cell_depth) + sampling_probability = constant_op.constant(sampling_probability) + + if use_next_inputs_fn: + def next_inputs_fn(outputs): + # Use 
deterministic function for test. + samples = math_ops.argmax(outputs, axis=1) + return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32) + else: + next_inputs_fn = None + + sampler = sampler_py.ScheduledOutputTrainingSampler( + sampling_probability=sampling_probability, + time_major=False, + next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + + (first_finished, + first_inputs, + first_state) = my_decoder.initialize(input_t, + sequence_length=sequence_length, + initial_state=initial_state, + auxiliary_inputs=auxiliary_inputs) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + + if use_next_inputs_fn: + output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output) + + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + + fetches = { + "batch_size": 
batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + } + if use_next_inputs_fn: + fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn + + eval_result = self.evaluate(fetches) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + eval_result["step_finished"]) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + batch_where_not_sampling = np.where(np.logical_not(sample_ids)) + batch_where_sampling = np.where(sample_ids) + + auxiliary_inputs_to_concat = ( + auxiliary_inputs[:, 1] if use_auxiliary_inputs else + np.array([]).reshape(batch_size, 0).astype(np.float32)) + + expected_next_sampling_inputs = np.concatenate( + (eval_result["output_after_next_inputs_fn"][batch_where_sampling] + if use_next_inputs_fn else + eval_result["step_outputs"].rnn_output[batch_where_sampling], + auxiliary_inputs_to_concat[batch_where_sampling]), + axis=-1) + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_sampling], + expected_next_sampling_inputs) + + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_not_sampling], + np.concatenate( + (np.squeeze(inputs[batch_where_not_sampling, 1], axis=0), + auxiliary_inputs_to_concat[batch_where_not_sampling]), + axis=-1)) + + def testStepWithScheduledOutputTrainingHelperWithoutNextInputsFnOrAuxInputs( + self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=False, + use_auxiliary_inputs=False) + + def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=True, + use_auxiliary_inputs=False) + 
+ def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=False, + use_auxiliary_inputs=True) + + def testStepWithScheduledOutputTrainingHelperWithNextInputsFnAndAuxInputs( + self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=True, + use_auxiliary_inputs=True) + + def testStepWithScheduledOutputTrainingHelperWithNoSampling(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.0, use_next_inputs_fn=True, + use_auxiliary_inputs=True) + + def testStepWithInferenceHelperCategorical(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size + start_token = 0 + end_token = 6 + + start_inputs = array_ops.one_hot( + np.ones(batch_size, dtype=np.int32) * start_token, + vocabulary_size) + + # The sample function samples categorically from the logits. + sample_fn = lambda x: sampler_py.categorical_sample(logits=x) + # The next inputs are a one-hot encoding of the sampled labels. 
+ next_inputs_fn = ( + lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)) + end_fn = lambda sample_ids: math_ops.equal(sample_ids, end_token) + + with self.cached_session(use_gpu=True): + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.InferenceSampler( + sample_fn, sample_shape=(), sample_dtype=dtypes.int32, end_fn=end_fn, + next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + start_inputs, initial_state=initial_state) + + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": 
first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = (sample_ids == end_token) + expected_step_next_inputs = np.zeros((batch_size, vocabulary_size)) + expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0 + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + def testStepWithInferenceHelperMultilabel(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size + start_token = 0 + end_token = 6 + + start_inputs = array_ops.one_hot( + np.ones(batch_size, dtype=np.int32) * start_token, + vocabulary_size) + + # The sample function samples independent bernoullis from the logits. + sample_fn = ( + lambda x: sampler_py.bernoulli_sample(logits=x, dtype=dtypes.bool)) + # The next inputs are a one-hot encoding of the sampled labels. 
+ next_inputs_fn = math_ops.to_float + end_fn = lambda sample_ids: sample_ids[:, end_token] + + with self.cached_session(use_gpu=True): + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.InferenceSampler( + sample_fn, sample_shape=[cell_depth], sample_dtype=dtypes.bool, + end_fn=end_fn, next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + start_inputs, initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, cell_depth), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, 
+ "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = sample_ids[:, end_token] + expected_step_next_inputs = sample_ids.astype(np.float32) + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 5e28e651c666b1c448f778fc9c02d637ce817bae..56f2a0acc9f2e6f951c5df26a53a31645697da4f 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops @@ -530,11 +533,10 @@ class BeamSearchDecoderTest(test.TestCase): return (shape[1], shape[0]) + shape[2:] return shape - self.assertTrue( - isinstance(final_outputs, - beam_search_decoder.FinalBeamSearchDecoderOutput)) - self.assertTrue( - isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)) + 
self.assertIsInstance( + final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput) + self.assertIsInstance( + final_state, beam_search_decoder.BeamSearchDecoderState) beam_search_decoder_output = final_outputs.beam_search_decoder_output self.assertEqual( @@ -574,5 +576,119 @@ class BeamSearchDecoderTest(test.TestCase): with_alignment_history=True) +@test_util.run_all_in_graph_and_eager_modes +class BeamSearchDecoderV2Test(test.TestCase): + + def _testDynamicDecodeRNN(self, time_major, has_attention, + with_alignment_history=False): + encoder_sequence_length = np.array([3, 2, 3, 1, 1]) + decoder_sequence_length = np.array([2, 0, 1, 2, 3]) + batch_size = 5 + decoder_max_time = 4 + input_depth = 7 + cell_depth = 9 + attention_depth = 6 + vocab_size = 20 + end_token = vocab_size - 1 + start_token = 0 + embedding_dim = 50 + max_out = max(decoder_sequence_length) + output_layer = layers.Dense(vocab_size, use_bias=True, activation=None) + beam_width = 3 + + with self.cached_session(): + batch_size_tensor = constant_op.constant(batch_size) + embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32) + cell = rnn_cell.LSTMCell(cell_depth) + initial_state = cell.zero_state(batch_size, dtypes.float32) + coverage_penalty_weight = 0.0 + if has_attention: + coverage_penalty_weight = 0.2 + inputs = array_ops.placeholder_with_default( + np.random.randn(batch_size, decoder_max_time, input_depth).astype( + np.float32), + shape=(None, None, input_depth)) + tiled_inputs = beam_search_decoder.tile_batch( + inputs, multiplier=beam_width) + tiled_sequence_length = beam_search_decoder.tile_batch( + encoder_sequence_length, multiplier=beam_width) + attention_mechanism = attention_wrapper.BahdanauAttention( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + initial_state = beam_search_decoder.tile_batch( + initial_state, multiplier=beam_width) + cell = attention_wrapper.AttentionWrapper( + cell=cell, + 
attention_mechanism=attention_mechanism, + attention_layer_size=attention_depth, + alignment_history=with_alignment_history) + cell_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) + if has_attention: + cell_state = cell_state.clone(cell_state=initial_state) + bsd = beam_search_decoder.BeamSearchDecoderV2( + cell=cell, + beam_width=beam_width, + output_layer=output_layer, + length_penalty_weight=0.0, + coverage_penalty_weight=coverage_penalty_weight, + output_time_major=time_major, + maximum_iterations=max_out) + + final_outputs, final_state, final_sequence_lengths = bsd( + embedding, + start_tokens=array_ops.fill([batch_size_tensor], start_token), + end_token=end_token, + initial_state=cell_state) + + def _t(shape): + if time_major: + return (shape[1], shape[0]) + shape[2:] + return shape + + self.assertIsInstance( + final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput) + self.assertIsInstance( + final_state, beam_search_decoder.BeamSearchDecoderState) + + beam_search_decoder_output = final_outputs.beam_search_decoder_output + expected_seq_length = 3 if context.executing_eagerly() else None + self.assertEqual( + _t((batch_size, expected_seq_length, beam_width)), + tuple(beam_search_decoder_output.scores.get_shape().as_list())) + self.assertEqual( + _t((batch_size, expected_seq_length, beam_width)), + tuple(final_outputs.predicted_ids.get_shape().as_list())) + + self.evaluate(variables.global_variables_initializer()) + eval_results = self.evaluate({ + 'final_outputs': final_outputs, + 'final_sequence_lengths': final_sequence_lengths + }) + + max_sequence_length = np.max(eval_results['final_sequence_lengths']) + + # A smoke test + self.assertEqual( + _t((batch_size, max_sequence_length, beam_width)), + eval_results['final_outputs'].beam_search_decoder_output.scores.shape) + self.assertEqual( + _t((batch_size, max_sequence_length, beam_width)), eval_results[ + 
'final_outputs'].beam_search_decoder_output.predicted_ids.shape) + + def testDynamicDecodeRNNBatchMajorNoAttention(self): + self._testDynamicDecodeRNN(time_major=False, has_attention=False) + + def testDynamicDecodeRNNBatchMajorYesAttention(self): + self._testDynamicDecodeRNN(time_major=False, has_attention=True) + + def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self): + self._testDynamicDecodeRNN( + time_major=False, + has_attention=True, + with_alignment_history=True) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5bba2b32e940aa4d5984821ebd3845d7f272549 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py @@ -0,0 +1,169 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for contrib.seq2seq.python.seq2seq.decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +@keras_parameterized.run_all_keras_modes +class DecodeV2RNNTest(keras_parameterized.TestCase, test.TestCase): + """Tests for DecoderV2.""" + + def _testDecodeRNN(self, time_major, maximum_iterations=None): + + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + max_out = max(sequence_length) + + with self.cached_session(use_gpu=True): + if time_major: + inputs = np.random.randn(max_time, batch_size, + input_depth).astype(np.float32) + else: + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + cell = rnn_cell.LSTMCell(cell_depth) + sampler = sampler_py.TrainingSampler(time_major=time_major) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler, + output_time_major=time_major, + maximum_iterations=maximum_iterations) + + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + (final_outputs, unused_final_state, final_sequence_length) = my_decoder( + input_t, initial_state=initial_state, sequence_length=sequence_length) + + def _t(shape): + if time_major: + return (shape[1], shape[0]) + shape[2:] + return shape + + if 
not context.executing_eagerly(): + self.assertEqual((batch_size,), + tuple(final_sequence_length.get_shape().as_list())) + self.assertEqual( + _t((batch_size, None, cell_depth)), + tuple(final_outputs.rnn_output.get_shape().as_list())) + self.assertEqual( + _t((batch_size, None)), + tuple(final_outputs.sample_id.get_shape().as_list())) + + self.evaluate(variables.global_variables_initializer()) + final_outputs = self.evaluate(final_outputs) + final_sequence_length = self.evaluate(final_sequence_length) + + # Mostly a smoke test + time_steps = max_out + expected_length = sequence_length + if maximum_iterations is not None: + time_steps = min(max_out, maximum_iterations) + expected_length = [min(x, maximum_iterations) for x in expected_length] + if context.executing_eagerly() and maximum_iterations != 0: + # Only check the shape of output when maximum_iterations > 0, see + # b/123431432 for more details. + self.assertEqual( + _t((batch_size, time_steps, cell_depth)), + final_outputs.rnn_output.shape) + self.assertEqual( + _t((batch_size, time_steps)), final_outputs.sample_id.shape) + self.assertItemsEqual(expected_length, final_sequence_length) + + def testDynamicDecodeRNNBatchMajor(self): + self._testDecodeRNN(time_major=False) + + def testDynamicDecodeRNNTimeMajor(self): + self._testDecodeRNN(time_major=True) + + def testDynamicDecodeRNNZeroMaxIters(self): + self._testDecodeRNN(time_major=True, maximum_iterations=0) + + def testDynamicDecodeRNNOneMaxIter(self): + self._testDecodeRNN(time_major=True, maximum_iterations=1) + + def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + self, use_sequence_length): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + max_out = max(sequence_length) + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + inputs = constant_op.constant(inputs) + + cell = rnn_cell.LSTMCell(cell_depth) + 
zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size) + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, sampler=sampler, impute_finished=use_sequence_length) + + final_decoder_outputs, final_decoder_state, _ = my_decoder( + inputs, initial_state=zero_state, sequence_length=sequence_length) + + final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn( + cell, + inputs, + sequence_length=sequence_length if use_sequence_length else None, + initial_state=zero_state) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "final_decoder_outputs": final_decoder_outputs, + "final_decoder_state": final_decoder_state, + "final_rnn_outputs": final_rnn_outputs, + "final_rnn_state": final_rnn_state + }) + + # Decoder only runs out to max_out; ensure values are identical + # to dynamic_rnn, which also zeros out outputs and passes along state. + self.assertAllClose(eval_result["final_decoder_outputs"].rnn_output, + eval_result["final_rnn_outputs"][:, 0:max_out, :]) + if use_sequence_length: + self.assertAllClose(eval_result["final_decoder_state"], + eval_result["final_rnn_state"]) + + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNWithSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + use_sequence_length=True) + + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNNoSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + use_sequence_length=False) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py index 5aa32b532ffcf5772f6ace26662f5e5471cf6923..41b2a53ca5b178be9b04446c81d832575e5ed75b 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py @@ -14,80 +14,254 @@ # 
============================================================================== """Tests for contrib.seq2seq.python.seq2seq.loss_ops.""" -# pylint: disable=unused-import,g-bad-import-order from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: enable=unused-import import numpy as np from tensorflow.contrib.seq2seq.python.ops import loss from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class LossTest(test.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 3 + self.number_of_classes = 5 + logits = [ + constant_op.constant(i + 0.5, shape=[self.batch_size, + self.number_of_classes]) + for i in range(self.sequence_length) + ] + self.logits = array_ops.stack(logits, axis=1) + targets = [ + constant_op.constant(i, dtypes.int32, shape=[self.batch_size]) + for i in range(self.sequence_length) + ] + self.targets = array_ops.stack(targets, axis=1) + weights = [ + constant_op.constant(1.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + self.weights = array_ops.stack(weights, axis=1) + # expected_loss = sparse_softmax_cross_entropy_with_logits(targets, logits) + # where targets = [0, 1, 2], and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5] + self.expected_loss = 1.60944 + def testSequenceLoss(self): - with self.session(use_gpu=True) as sess: - with variable_scope.variable_scope( - 'root', initializer=init_ops.constant_initializer(0.5)): - batch_size = 2 - sequence_length = 3 - number_of_classes = 5 - logits = [ - constant_op.constant( - i + 0.5, shape=[batch_size, number_of_classes]) - for i in range(sequence_length) - ] - logits = 
array_ops.stack(logits, axis=1) - targets = [ - constant_op.constant( - i, dtypes.int32, shape=[batch_size]) - for i in range(sequence_length) - ] - targets = array_ops.stack(targets, axis=1) - weights = [ - constant_op.constant( - 1.0, shape=[batch_size]) for i in range(sequence_length) - ] - weights = array_ops.stack(weights, axis=1) - - average_loss_per_example = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=True, - average_across_batch=True) - res = sess.run(average_loss_per_example) - self.assertAllClose(1.60944, res) - - average_loss_per_sequence = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=False, - average_across_batch=True) - res = sess.run(average_loss_per_sequence) - compare_per_sequence = np.ones((sequence_length)) * 1.60944 - self.assertAllClose(compare_per_sequence, res) - - average_loss_per_batch = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=True, - average_across_batch=False) - res = sess.run(average_loss_per_batch) - compare_per_batch = np.ones((batch_size)) * 1.60944 - self.assertAllClose(compare_per_batch, res) - - total_loss = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=False, - average_across_batch=False) - res = sess.run(total_loss) - compare_total = np.ones((batch_size, sequence_length)) * 1.60944 - self.assertAllClose(compare_total, res) + with self.test_session(use_gpu=True): + average_loss_per_example = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=True, + average_across_batch=True) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + average_loss_per_sequence = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=False, + average_across_batch=True) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + 
self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=True, + average_across_batch=False) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + total_loss = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=False, + average_across_batch=False) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testSequenceLossClass(self): + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=True, + average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=True, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_batch = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + 
sum_over_batch=False) + total_loss = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testSumReduction(self): + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False) + average_loss_per_batch = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testWeightedSumReduction(self): + weights = [ + constant_op.constant(1.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + # Make the last element in 
the sequence to have zero weights. + weights[-1] = constant_op.constant(0.0, shape=[self.batch_size]) + self.weights = array_ops.stack(weights, axis=1) + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + # The last element in every sequence are zeros, which will be filtered. + compare_per_sequence[-1] = 0. + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False) + average_loss_per_batch = seq_loss(self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss(self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + # The last element in every sequence are zeros, which will be filtered. 
+ compare_total[:, -1] = 0 + self.assertAllClose(compare_total, res) + + def testZeroWeights(self): + weights = [ + constant_op.constant(0.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + weights = array_ops.stack(weights, axis=1) + with self.test_session(use_gpu=True): + average_loss_per_example = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=True, + average_across_batch=True) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(0.0, res) + + average_loss_per_sequence = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=False, + average_across_batch=True) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.zeros((self.sequence_length)) + self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=True, + average_across_batch=False) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.zeros((self.batch_size)) + self.assertAllClose(compare_per_batch, res) + + total_loss = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=False, + average_across_batch=False) + res = self.evaluate(total_loss) + compare_total = np.zeros((self.batch_size, self.sequence_length)) + self.assertAllClose(compare_total, res) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 77e9f848b137911b53e1b4df5dd740fe38af55bb..79c2ac2f500307ba23b6d97a7a30c6d04cea5176 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -25,9 +25,13 @@ import math import numpy as np from tensorflow.contrib.framework.python.framework import tensor_util +from tensorflow.python.eager import context from 
tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import initializers +from tensorflow.python.keras import layers +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.layers import base as layers_base from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops @@ -72,77 +76,6 @@ class AttentionMechanism(object): raise NotImplementedError -def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): - """Convert to tensor and possibly mask `memory`. - - Args: - memory: `Tensor`, shaped `[batch_size, max_time, ...]`. - memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. - check_inner_dims_defined: Python boolean. If `True`, the `memory` - argument's shape is checked to ensure all but the two outermost - dimensions are fully defined. - - Returns: - A (possibly masked), checked, new `memory`. - - Raises: - ValueError: If `check_inner_dims_defined` is `True` and not - `memory.shape[2:].is_fully_defined()`. 
- """ - memory = nest.map_structure( - lambda m: ops.convert_to_tensor(m, name="memory"), memory) - if memory_sequence_length is not None: - memory_sequence_length = ops.convert_to_tensor( - memory_sequence_length, name="memory_sequence_length") - if check_inner_dims_defined: - def _check_dims(m): - if not m.get_shape()[2:].is_fully_defined(): - raise ValueError("Expected memory %s to have fully defined inner dims, " - "but saw shape: %s" % (m.name, m.get_shape())) - nest.map_structure(_check_dims, memory) - if memory_sequence_length is None: - seq_len_mask = None - else: - seq_len_mask = array_ops.sequence_mask( - memory_sequence_length, - maxlen=array_ops.shape(nest.flatten(memory)[0])[1], - dtype=nest.flatten(memory)[0].dtype) - seq_len_batch_size = ( - tensor_shape.dimension_value(memory_sequence_length.shape[0]) - or array_ops.shape(memory_sequence_length)[0]) - def _maybe_mask(m, seq_len_mask): - rank = m.get_shape().ndims - rank = rank if rank is not None else array_ops.rank(m) - extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) - m_batch_size = tensor_shape.dimension_value( - m.shape[0]) or array_ops.shape(m)[0] - if memory_sequence_length is not None: - message = ("memory_sequence_length and memory tensor batch sizes do not " - "match.") - with ops.control_dependencies([ - check_ops.assert_equal( - seq_len_batch_size, m_batch_size, message=message)]): - seq_len_mask = array_ops.reshape( - seq_len_mask, - array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) - return m * seq_len_mask - else: - return m - return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) - - -def _maybe_mask_score(score, memory_sequence_length, score_mask_value): - if memory_sequence_length is None: - return score - message = ("All values in memory_sequence_length must greater than zero.") - with ops.control_dependencies( - [check_ops.assert_positive(memory_sequence_length, message=message)]): - score_mask = array_ops.sequence_mask( - 
memory_sequence_length, maxlen=array_ops.shape(score)[1]) - score_mask_values = score_mask_value * array_ops.ones_like(score) - return array_ops.where(score_mask, score, score_mask_values) - - class _BaseAttentionMechanism(AttentionMechanism): """A base AttentionMechanism class providing common functionality. @@ -205,12 +138,14 @@ class _BaseAttentionMechanism(AttentionMechanism): self._memory_layer.dtype).as_numpy_dtype(-np.inf) self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda probability_fn( - _maybe_mask_score(score, memory_sequence_length, score_mask_value), + _maybe_mask_score(score, + memory_sequence_length=memory_sequence_length, + score_mask_value=score_mask_value), prev)) with ops.name_scope( name, "BaseAttentionMechanismInit", nest.flatten(memory)): self._values = _prepare_memory( - memory, memory_sequence_length, + memory, memory_sequence_length=memory_sequence_length, check_inner_dims_defined=check_inner_dims_defined) self._keys = ( self.memory_layer(self._values) if self.memory_layer # pylint: disable=not-callable @@ -286,6 +221,376 @@ class _BaseAttentionMechanism(AttentionMechanism): return self.initial_alignments(batch_size, dtype) +class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): + """A base AttentionMechanism class providing common functionality. + + Common functionality includes: + 1. Storing the query and memory layers. + 2. Preprocessing and storing the memory. + + Note that this layer takes memory as its init parameter, which is an + anti-pattern of Keras API, we have to keep the memory as init parameter for + performance and dependency reason. Under the hood, during `__init__()`, it + will invoke `base_layer.__call__(memory, setup_memory=True)`. This will let + keras to keep track of the memory tensor as the input of this layer. Once + the `__init__()` is done, then user can query the attention by + `score = att_obj([query, state])`, and use it as a normal keras layer. 
+ + Special attention is needed when adding using this class as the base layer for + new attention: + 1. Build() could be invoked at least twice. So please make sure weights are + not duplicated. + 2. Layer.get_weights() might return different set of weights if the instance + has `query_layer`. The query_layer weights is not initialized until the + memory is configured. + + Also note that this layer does not work with Keras model when + `model.compile(run_eagerly=True)` due to the fact that this layer is stateful. + The support for that will be added in a future version. + """ + + def __init__(self, + memory, + probability_fn, + query_layer=None, + memory_layer=None, + memory_sequence_length=None, + **kwargs): + """Construct base AttentionMechanism class. + + Args: + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + probability_fn: A `callable`. Converts the score and previous alignments + to probabilities. Its signature should be: + `probabilities = probability_fn(score, state)`. + query_layer: (optional): Instance of `tf.keras.Layer`. The layer's depth + must match the depth of `memory_layer`. If `query_layer` is not + provided, the shape of `query` must match that of `memory_layer`. + memory_layer: (optional): Instance of `tf.keras.Layer`. The layer's + depth must match the depth of `query_layer`. + If `memory_layer` is not provided, the shape of `memory` must match + that of `query_layer`. + memory_sequence_length (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. + **kwargs: Dictionary that contains other common arguments for layer + creation. 
+ """ + if (query_layer is not None + and not isinstance(query_layer, layers.Layer)): + raise TypeError( + "query_layer is not a Layer: %s" % type(query_layer).__name__) + if (memory_layer is not None + and not isinstance(memory_layer, layers.Layer)): + raise TypeError( + "memory_layer is not a Layer: %s" % type(memory_layer).__name__) + self.query_layer = query_layer + self.memory_layer = memory_layer + if self.memory_layer is not None and "dtype" not in kwargs: + kwargs["dtype"] = self.memory_layer.dtype + super(_BaseAttentionMechanismV2, self).__init__(**kwargs) + if not callable(probability_fn): + raise TypeError("probability_fn must be callable, saw type: %s" % + type(probability_fn).__name__) + self.probability_fn = probability_fn + + self.keys = None + self.values = None + self.batch_size = None + self._memory_initialized = False + self._check_inner_dims_defined = True + self.supports_masking = True + self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf) + + if memory is not None: + # Setup the memory by self.__call__() with memory and memory_seq_length. + # This will make the attention follow the keras convention which takes + # all the tensor inputs via __call__(). + if memory_sequence_length is None: + inputs = memory + else: + inputs = [memory, memory_sequence_length] + + self.values = super(_BaseAttentionMechanismV2, self).__call__( + inputs, setup_memory=True) + + def build(self, input_shape): + if not self._memory_initialized: + # This is for setting up the memory, which contains memory and optional + # memory_sequence_length. Build the memory_layer with memory shape. + if self.memory_layer is not None and not self.memory_layer.built: + if isinstance(input_shape, list): + self.memory_layer.build(input_shape[0]) + else: + self.memory_layer.build(input_shape) + else: + # The input_shape should be query.shape and state.shape. Use the query + # to init the query layer. 
+ if self.query_layer is not None and not self.query_layer.built: + self.query_layer.build(input_shape[0]) + + def __call__(self, inputs, **kwargs): + """Preprocess the inputs before calling `base_layer.__call__()`. + + Note that there are situation here, one for setup memory, and one with + actual query and state. + 1. When the memory has not been configured, we just pass all the param to + base_layer.__call__(), which will then invoke self.call() with proper + inputs, which allows this class to setup memory. + 2. When the memory has already been setup, the input should contain query + and state, and optionally processed memory. If the processed memory is + not included in the input, we will have to append it to the inputs and + give it to the base_layer.__call__(). The processed memory is the output + of first invocation of self.__call__(). If we don't add it here, then from + keras perspective, the graph is disconnected since the output from + previous call is never used. + + Args: + inputs: the inputs tensors. + **kwargs: dict, other keyeword arguments for the `__call__()` + """ + if self._memory_initialized: + if len(inputs) not in (2, 3): + raise ValueError("Expect the inputs to have 2 or 3 tensors, got %d" % + len(inputs)) + if len(inputs) == 2: + # We append the calculated memory here so that the graph will be + # connected. + inputs.append(self.values) + return super(_BaseAttentionMechanismV2, self).__call__(inputs, **kwargs) + + def call(self, inputs, mask=None, setup_memory=False, **kwargs): + """Setup the memory or query the attention. + + There are two case here, one for setup memory, and the second is query the + attention score. `setup_memory` is the flag to indicate which mode it is. + The input list will be treated differently based on that flag. + + Args: + inputs: a list of tensor that could either be `query` and `state`, or + `memory` and `memory_sequence_length`. 
+ `query` is the tensor of dtype matching `memory` and shape + `[batch_size, query_depth]`. + `state` is the tensor of dtype matching `memory` and shape + `[batch_size, alignments_size]`. (`alignments_size` is memory's + `max_time`). + `memory` is the memory to query; usually the output of an RNN encoder. + The tensor should be shaped `[batch_size, max_time, ...]`. + `memory_sequence_length` (optional) is the sequence lengths for the + batch entries in memory. If provided, the memory tensor rows are masked + with zeros for values past the respective sequence lengths. + mask: optional bool tensor with shape `[batch, max_time]` for the mask of + memory. If it is not None, the corresponding item of the memory should + be filtered out during calculation. + setup_memory: boolean, whether the input is for setting up memory, or + query attention. + **kwargs: Dict, other keyword arguments for the call method. + Returns: + Either processed memory or attention score, based on `setup_memory`. + """ + if setup_memory: + if isinstance(inputs, list): + if len(inputs) not in (1, 2): + raise ValueError("Expect inputs to have 1 or 2 tensors, got %d" % + len(inputs)) + memory = inputs[0] + memory_sequence_length = inputs[1] if len(inputs) == 2 else None + memory_mask = mask + else: + memory, memory_sequence_length = inputs, None + memory_mask = mask + self._setup_memory(memory, memory_sequence_length, memory_mask) + # We force the self.built to false here since only memory is initialized, + # but the real query/state has not been call() yet. The layer should be + # build and call again. + self.built = False + # Return the processed memory in order to create the Keras connectivity + # data for it. 
+ return self.values + else: + if not self._memory_initialized: + raise ValueError("Cannot query the attention before the setup of " + "memory") + if len(inputs) not in (2, 3): + raise ValueError("Expect the inputs to have query, state, and optional " + "processed memory, got %d items" % len(inputs)) + # Ignore the rest of the inputs and only care about the query and state + query, state = inputs[0], inputs[1] + return self._calculate_attention(query, state) + + def _setup_memory(self, memory, memory_sequence_length=None, + memory_mask=None): + """Pre-process the memory before actually query the memory. + + This should only be called once at the first invocation of call(). + + Args: + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros for + values past the respective sequence lengths. + memory_mask: (Optional) The boolean tensor with shape `[batch_size, + max_time]`. For any value equal to False, the corresponding value in + memory should be ignored. + """ + if self._memory_initialized: + raise ValueError("The memory for the attention has already been setup.") + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask cannot be " + "used at same time for attention.") + with ops.name_scope( + self.name, "BaseAttentionMechanismInit", nest.flatten(memory)): + self.values = _prepare_memory( + memory, + memory_sequence_length=memory_sequence_length, + memory_mask=memory_mask, + check_inner_dims_defined=self._check_inner_dims_defined) + # Mark the value as check since the memory and memory mask might not + # passed from __call__(), which does not have proper keras metadata. + # TODO(omalleyt): Remove this hack once the mask the has proper keras + # history. 
+ base_layer_utils.mark_checked(self.values) + if self.memory_layer is not None: + self.keys = self.memory_layer(self.values) + else: + self.keys = self.values + self.batch_size = ( + tensor_shape.dimension_value(self.keys.shape[0]) or + array_ops.shape(self.keys)[0]) + self._alignments_size = (tensor_shape.dimension_value(self.keys.shape[1]) + or array_ops.shape(self.keys)[1]) + if memory_mask is not None: + unwrapped_probability_fn = self.probability_fn + def _mask_probability_fn(score, prev): + return unwrapped_probability_fn( + _maybe_mask_score( + score, + memory_mask=memory_mask, + memory_sequence_length=memory_sequence_length, + score_mask_value=self.score_mask_value), prev) + self.probability_fn = _mask_probability_fn + self._memory_initialized = True + + def _calculate_attention(self, query, state): + raise NotImplementedError( + "_calculate_attention need to be implemented by subclasses.") + + def compute_mask(self, inputs, mask=None): + # There real input of the attention is query and state, and the memory layer + # mask shouldn't be pass down. Returning None for all output mask here. + return None, None + + def get_config(self): + config = {} + # Since the probability_fn is likely to be a wrapped function, the child + # class should preserve the original function and how its wrapped. + + if self.query_layer is not None: + config["query_layer"] = { + "class_name": self.query_layer.__class__.__name__, + "config": self.query_layer.get_config(), + } + if self.memory_layer is not None: + config["memory_layer"] = { + "class_name": self.memory_layer.__class__.__name__, + "config": self.memory_layer.get_config(), + } + # memory is a required init parameter and its a tensor. It cannot be + # serialized to config, so we put a placeholder for it. 
+ config["memory"] = None + base_config = super(_BaseAttentionMechanismV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def _process_probability_fn(self, func_name): + """Helper method to retrieve the probably function by string input.""" + valid_probability_fns = { + "softmax": nn_ops.softmax, + "hardmax": hardmax, + } + if func_name not in valid_probability_fns.keys(): + raise ValueError("Invalid probability function: %s, options are %s" % + (func_name, valid_probability_fns.keys())) + return valid_probability_fns[func_name] + + @classmethod + def deserialize_inner_layer_from_config(cls, config, custom_objects): + """Helper method that reconstruct the query and memory from the config. + + In the get_config() method, the query and memory layer configs are + serialized into dict for persistence, this method perform the reverse action + to reconstruct the layer from the config. + + Args: + config: dict, the configs that will be used to reconstruct the object. + custom_objects: dict mapping class names (or function names) of custom + (non-Keras) objects to class/functions. + Returns: + config: dict, the config with layer instance created, which is ready to be + used as init parameters. + """ + # Reconstruct the query and memory layer for parent class. + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + # Instead of updating the input, create a copy and use that. 
+ config = config.copy() + query_layer_config = config.pop("query_layer", None) + if query_layer_config: + query_layer = deserialize_layer(query_layer_config, + custom_objects=custom_objects) + config["query_layer"] = query_layer + memory_layer_config = config.pop("memory_layer", None) + if memory_layer_config: + memory_layer = deserialize_layer(memory_layer_config, + custom_objects=custom_objects) + config["memory_layer"] = memory_layer + return config + + @property + def alignments_size(self): + return self._alignments_size + + @property + def state_size(self): + return self._alignments_size + + def initial_alignments(self, batch_size, dtype): + """Creates the initial alignment values for the `AttentionWrapper` class. + + This is important for AttentionMechanisms that use the previous alignment + to calculate the alignment at the next time step (e.g. monotonic attention). + + The default behavior is to return a tensor of all zeros. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A `dtype` tensor shaped `[batch_size, alignments_size]` + (`alignments_size` is the values' `max_time`). + """ + max_time = self._alignments_size + return _zero_state_tensors(max_time, batch_size, dtype) + + def initial_state(self, batch_size, dtype): + """Creates the initial state values for the `AttentionWrapper` class. + + This is important for AttentionMechanisms that use the previous alignment + to calculate the alignment at the next time step (e.g. monotonic attention). + + The default behavior is to return the same output as initial_alignments. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A structure of all-zero tensors with shapes as described by `state_size`. + """ + return self.initial_alignments(batch_size, dtype) + + def _luong_score(query, keys, scale): """Implements Luong-style (multiplicative) scoring function. 
@@ -304,7 +609,7 @@ def _luong_score(query, keys, scale): Args: query: Tensor, shape `[batch_size, num_units]` to compare to keys. keys: Processed memory, shape `[batch_size, max_time, num_units]`. - scale: Whether to apply a scale to the score function. + scale: the optional tensor to scale the attention score. Returns: A `[batch_size, max_time]` tensor of unnormalized score values. @@ -320,7 +625,6 @@ def _luong_score(query, keys, scale): "Query (%s) has units: %s. Keys (%s) have units: %s. " "Perhaps you need to set num_units to the keys' dimension (%s)?" % (query, depth, keys, key_units, key_units)) - dtype = query.dtype # Reshape from [batch_size, depth] to [batch_size, 1, depth] # for matmul. @@ -338,12 +642,8 @@ def _luong_score(query, keys, scale): score = math_ops.matmul(query, keys, transpose_b=True) score = array_ops.squeeze(score, [1]) - if scale: - # Scalar used in weight scaling - g = variable_scope.get_variable( - "attention_g", dtype=dtype, - initializer=init_ops.ones_initializer, shape=()) - score = g * score + if scale is not None: + score = scale * score return score @@ -354,8 +654,8 @@ class LuongAttention(_BaseAttentionMechanism): as described in: Minh-Thang Luong, Hieu Pham, Christopher D. Manning. - "Effective Approaches to Attention-based Neural Machine Translation." - EMNLP 2015. https://arxiv.org/abs/1508.04025 + [Effective Approaches to Attention-based Neural Machine Translation. + EMNLP 2015.](https://arxiv.org/abs/1508.04025) The second is the scaled form inspired partly by the normalized form of Bahdanau attention. @@ -429,13 +729,133 @@ class LuongAttention(_BaseAttentionMechanism): `max_time`). 
""" with variable_scope.variable_scope(None, "luong_attention", [query]): - score = _luong_score(query, self._keys, self._scale) + attention_g = None + if self._scale: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.ones_initializer, shape=()) + score = _luong_score(query, self._keys, attention_g) alignments = self._probability_fn(score, state) next_state = alignments return alignments, next_state -def _bahdanau_score(processed_query, keys, normalize): +class LuongAttentionV2(_BaseAttentionMechanismV2): + """Implements Luong-style (multiplicative) attention scoring. + + This attention has two forms. The first is standard Luong attention, + as described in: + + Minh-Thang Luong, Hieu Pham, Christopher D. Manning. + [Effective Approaches to Attention-based Neural Machine Translation. + EMNLP 2015.](https://arxiv.org/abs/1508.04025) + + The second is the scaled form inspired partly by the normalized form of + Bahdanau attention. + + To enable the second form, construct the object with parameter + `scale=True`. + """ + + def __init__(self, + units, + memory, + memory_sequence_length=None, + scale=False, + probability_fn="softmax", + dtype=None, + name="LuongAttention", + **kwargs): + """Construct the AttentionMechanism mechanism. + + Args: + units: The depth of the attention mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. + scale: Python boolean. Whether to scale the energy term. + probability_fn: (optional) string, the name of function to convert the + attention score to probabilities. The default is `softmax` which is + `tf.nn.softmax`. Other options is `hardmax`, which is hardmax() within + this module. 
Any other value will result intovalidation error. Default + to use `softmax`. + dtype: The data type for the memory layer of the attention mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + # For LuongAttention, we only transform the memory layer; thus + # num_units **must** match expected the query depth. + self.probability_fn_name = probability_fn + probability_fn = self._process_probability_fn(self.probability_fn_name) + wrapped_probability_fn = lambda score, _: probability_fn(score) + if dtype is None: + dtype = dtypes.float32 + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.scale = scale + self.scale_weight = None + super(LuongAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + + def build(self, input_shape): + super(LuongAttentionV2, self).build(input_shape) + if self.scale and self.scale_weight is None: + self.scale_weight = self.add_weight( + "attention_g", initializer=init_ops.ones_initializer, shape=()) + self.built = True + + def _calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: Same as the alignments. 
+ """ + score = _luong_score(query, self.keys, self.scale_weight) + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "scale": self.scale, + "probability_fn": self.probability_fn_name, + } + base_config = super(LuongAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + +def _bahdanau_score(processed_query, keys, attention_v, + attention_g=None, attention_b=None): """Implements Bahdanau-style (additive) scoring function. This attention has two forms. The first is Bhandanau attention, @@ -453,41 +873,28 @@ def _bahdanau_score(processed_query, keys, normalize): Training of Deep Neural Networks." https://arxiv.org/abs/1602.07868 - To enable the second form, set `normalize=True`. + To enable the second form, set please pass in attention_g and attention_b. Args: processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys. keys: Processed memory, shape `[batch_size, max_time, num_units]`. - normalize: Whether to normalize the score function. + attention_v: Tensor, shape `[num_units]`. + attention_g: Optional scalar tensor for normalization. + attention_b: Optional tensor with shape `[num_units]` for normalization. Returns: A `[batch_size, max_time]` tensor of unnormalized score values. """ - dtype = processed_query.dtype - # Get the number of hidden units from the trailing dimension of keys - num_units = tensor_shape.dimension_value( - keys.shape[2]) or array_ops.shape(keys)[2] # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. 
processed_query = array_ops.expand_dims(processed_query, 1) - v = variable_scope.get_variable( - "attention_v", [num_units], dtype=dtype) - if normalize: - # Scalar used in weight normalization - g = variable_scope.get_variable( - "attention_g", dtype=dtype, - initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))), - shape=()) - # Bias added prior to the nonlinearity - b = variable_scope.get_variable( - "attention_b", [num_units], dtype=dtype, - initializer=init_ops.zeros_initializer()) - # normed_v = g * v / ||v|| - normed_v = g * v * math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(v))) + if attention_g is not None and attention_b is not None: + normed_v = attention_g * attention_v * math_ops.rsqrt( + math_ops.reduce_sum(math_ops.square(attention_v))) return math_ops.reduce_sum( - normed_v * math_ops.tanh(keys + processed_query + b), [2]) + normed_v * math_ops.tanh(keys + processed_query + attention_b), [2]) else: - return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2]) + return math_ops.reduce_sum( + attention_v * math_ops.tanh(keys + processed_query), [2]) class BahdanauAttention(_BaseAttentionMechanism): @@ -578,12 +985,169 @@ class BahdanauAttention(_BaseAttentionMechanism): """ with variable_scope.variable_scope(None, "bahdanau_attention", [query]): processed_query = self.query_layer(query) if self.query_layer else query - score = _bahdanau_score(processed_query, self._keys, self._normalize) + attention_v = variable_scope.get_variable( + "attention_v", [self._num_units], dtype=query.dtype) + if not self._normalize: + attention_g = None + attention_b = None + else: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. 
/ self._num_units))), + shape=()) + attention_b = variable_scope.get_variable( + "attention_b", [self._num_units], dtype=query.dtype, + initializer=init_ops.zeros_initializer()) + + score = _bahdanau_score(processed_query, self._keys, attention_v, + attention_g=attention_g, attention_b=attention_b) alignments = self._probability_fn(score, state) next_state = alignments return alignments, next_state +class BahdanauAttentionV2(_BaseAttentionMechanismV2): + """Implements Bahdanau-style (additive) attention. + + This attention has two forms. The first is Bahdanau attention, + as described in: + + Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. + "Neural Machine Translation by Jointly Learning to Align and Translate." + ICLR 2015. https://arxiv.org/abs/1409.0473 + + The second is the normalized form. This form is inspired by the + weight normalization article: + + Tim Salimans, Diederik P. Kingma. + "Weight Normalization: A Simple Reparameterization to Accelerate + Training of Deep Neural Networks." + https://arxiv.org/abs/1602.07868 + + To enable the second form, construct the object with parameter + `normalize=True`. + """ + + def __init__(self, + units, + memory, + memory_sequence_length=None, + normalize=False, + probability_fn="softmax", + kernel_initializer="glorot_uniform", + dtype=None, + name="BahdanauAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. + normalize: Python boolean. Whether to normalize the energy term. + probability_fn: (optional) string, the name of function to convert the + attention score to probabilities. 
The default is `softmax` which is + `tf.nn.softmax`. Other options is `hardmax`, which is hardmax() within + this module. Any other value will result into validation error. Default + to use `softmax`. + kernel_initializer: (optional), the name of the initializer for the + attention kernel. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + self.probability_fn_name = probability_fn + probability_fn = self._process_probability_fn(self.probability_fn_name) + wrapped_probability_fn = lambda score, _: probability_fn(score) + if dtype is None: + dtype = dtypes.float32 + query_layer = kwargs.pop("query_layer", None) + if not query_layer: + query_layer = layers.Dense( + units, name="query_layer", use_bias=False, dtype=dtype) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.normalize = normalize + self.kernel_initializer = initializers.get(kernel_initializer) + self.attention_v = None + self.attention_g = None + self.attention_b = None + super(BahdanauAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, + query_layer=query_layer, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + + def build(self, input_shape): + super(BahdanauAttentionV2, self).build(input_shape) + if self.attention_v is None: + self.attention_v = self.add_weight( + "attention_v", [self.units], + dtype=self.dtype, + initializer=self.kernel_initializer) + if self.normalize and self.attention_g is None and self.attention_b is None: + self.attention_g = self.add_weight( + "attention_g", initializer=init_ops.constant_initializer( + math.sqrt((1. 
/ self.units))), shape=()) + self.attention_b = self.add_weight( + "attention_b", shape=[self.units], + initializer=init_ops.zeros_initializer()) + self.built = True + + def _calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: same as alignments. + """ + processed_query = self.query_layer(query) if self.query_layer else query + score = _bahdanau_score(processed_query, self.keys, self.attention_v, + attention_g=self.attention_g, + attention_b=self.attention_b) + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "normalize": self.normalize, + "probability_fn": self.probability_fn_name, + "kernel_initializer": initializers.serialize(self.kernel_initializer) + } + base_config = super(BahdanauAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + def safe_cumprod(x, *args, **kwargs): """Computes cumprod of x in logspace using cumsum to avoid underflow. @@ -766,6 +1330,34 @@ class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism): dtype=dtype) +class _BaseMonotonicAttentionMechanismV2(_BaseAttentionMechanismV2): + """Base attention mechanism for monotonic attention. 
+ + Simply overrides the initial_alignments function to provide a dirac + distribution, which is needed in order for the monotonic attention + distributions to have the correct behavior. + """ + + def initial_alignments(self, batch_size, dtype): + """Creates the initial alignment values for the monotonic attentions. + + Initializes to dirac distributions, i.e. [1, 0, 0, ...memory length..., 0] + for all entries in the batch. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A `dtype` tensor shaped `[batch_size, alignments_size]` + (`alignments_size` is the values' `max_time`). + """ + max_time = self._alignments_size + return array_ops.one_hot( + array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time, + dtype=dtype) + + class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): """Monotonic attention mechanism with Bahadanau-style energy function. @@ -860,7 +1452,22 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): with variable_scope.variable_scope( None, "bahdanau_monotonic_attention", [query]): processed_query = self.query_layer(query) if self.query_layer else query - score = _bahdanau_score(processed_query, self._keys, self._normalize) + attention_v = variable_scope.get_variable( + "attention_v", [self._num_units], dtype=query.dtype) + if not self._normalize: + attention_g = None + attention_b = None + else: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. 
/ self._num_units))), + shape=()) + attention_b = variable_scope.get_variable( + "attention_b", [self._num_units], dtype=query.dtype, + initializer=init_ops.zeros_initializer()) + score = _bahdanau_score(processed_query, self._keys, attention_v, + attention_g=attention_g, attention_b=attention_b) score_bias = variable_scope.get_variable( "attention_score_bias", dtype=processed_query.dtype, initializer=self._score_bias_init) @@ -870,6 +1477,164 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): return alignments, next_state +class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): + """Monotonic attention mechanism with Bahadanau-style energy function. + + This type of attention enforces a monotonic constraint on the attention + distributions; that is once the model attends to a given point in the memory + it can't attend to any prior points at subsequence output timesteps. It + achieves this by using the _monotonic_probability_fn instead of softmax to + construct its attention distributions. Since the attention scores are passed + through a sigmoid, a learnable scalar bias parameter is applied after the + score function and before the sigmoid. Otherwise, it is equivalent to + BahdanauAttention. This approach is proposed in + + Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, + "Online and Linear-Time Attention by Enforcing Monotonic Alignments." + ICML 2017. https://arxiv.org/abs/1704.00784 + """ + + def __init__(self, + units, + memory, + memory_sequence_length=None, + normalize=False, + sigmoid_noise=0., + sigmoid_noise_seed=None, + score_bias_init=0., + mode="parallel", + kernel_initializer="glorot_uniform", + dtype=None, + name="BahdanauMonotonicAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. 
+ memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. + normalize: Python boolean. Whether to normalize the energy term. + sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring + for `_monotonic_probability_fn` for more information. + sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. + score_bias_init: Initial value for score bias scalar. It's recommended to + initialize this to a negative value when the length of the memory is + large. + mode: How to compute the attention distribution. Must be one of + 'recursive', 'parallel', or 'hard'. See the docstring for + `tf.contrib.seq2seq.monotonic_attention` for more information. + kernel_initializer: (optional), the name of the initializer for the + attention kernel. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. 
+ """ + # Set up the monotonic probability fn with supplied parameters + if dtype is None: + dtype = dtypes.float32 + wrapped_probability_fn = functools.partial( + _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + seed=sigmoid_noise_seed) + query_layer = kwargs.pop("query_layer", None) + if not query_layer: + query_layer = layers.Dense( + units, name="query_layer", use_bias=False, dtype=dtype) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.normalize = normalize + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + self.kernel_initializer = initializers.get(kernel_initializer) + self.attention_v = None + self.attention_score_bias = None + self.attention_g = None + self.attention_b = None + super(BahdanauMonotonicAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, + query_layer=query_layer, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + + def build(self, input_shape): + super(BahdanauMonotonicAttentionV2, self).build(input_shape) + if self.attention_v is None: + self.attention_v = self.add_weight( + "attention_v", [self.units], dtype=self.dtype, + initializer=self.kernel_initializer) + if self.attention_score_bias is None: + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), dtype=self.dtype, + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) + if self.normalize and self.attention_g is None and self.attention_b is None: + self.attention_g = self.add_weight( + "attention_g", dtype=self.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. 
/ self.units))), + shape=()) + self.attention_b = self.add_weight( + "attention_b", [self.units], dtype=self.dtype, + initializer=init_ops.zeros_initializer()) + self.built = True + + def _calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + """ + processed_query = self.query_layer(query) if self.query_layer else query + score = _bahdanau_score(processed_query, self.keys, self.attention_v, + attention_g=self.attention_g, + attention_b=self.attention_b) + score += self.attention_score_bias + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "normalize": self.normalize, + "sigmoid_noise": self.sigmoid_noise, + "sigmoid_noise_seed": self.sigmoid_noise_seed, + "score_bias_init": self.score_bias_init, + "mode": self.mode, + "kernel_initializer": initializers.serialize(self.kernel_initializer), + } + base_config = super(BahdanauMonotonicAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): """Monotonic attention mechanism with Luong-style energy function. 
@@ -960,7 +1725,12 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): """ with variable_scope.variable_scope(None, "luong_monotonic_attention", [query]): - score = _luong_score(query, self._keys, self._scale) + attention_g = None + if self._scale: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.ones_initializer, shape=()) + score = _luong_score(query, self._keys, attention_g) score_bias = variable_scope.get_variable( "attention_score_bias", dtype=query.dtype, initializer=self._score_bias_init) @@ -970,6 +1740,139 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): return alignments, next_state +class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): + """Monotonic attention mechanism with Luong-style energy function. + + This type of attention enforces a monotonic constraint on the attention + distributions; that is once the model attends to a given point in the memory + it can't attend to any prior points at subsequence output timesteps. It + achieves this by using the _monotonic_probability_fn instead of softmax to + construct its attention distributions. Otherwise, it is equivalent to + LuongAttention. This approach is proposed in + + [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, + "Online and Linear-Time Attention by Enforcing Monotonic Alignments." + ICML 2017.](https://arxiv.org/abs/1704.00784) + """ + + def __init__(self, + units, + memory, + memory_sequence_length=None, + scale=False, + sigmoid_noise=0., + sigmoid_noise_seed=None, + score_bias_init=0., + mode="parallel", + dtype=None, + name="LuongMonotonicAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. 
+ memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. + scale: Python boolean. Whether to scale the energy term. + sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring + for `_monotonic_probability_fn` for more information. + sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. + score_bias_init: Initial value for score bias scalar. It's recommended to + initialize this to a negative value when the length of the memory is + large. + mode: How to compute the attention distribution. Must be one of + 'recursive', 'parallel', or 'hard'. See the docstring for + `tf.contrib.seq2seq.monotonic_attention` for more information. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. 
+ """ + # Set up the monotonic probability fn with supplied parameters + if dtype is None: + dtype = dtypes.float32 + wrapped_probability_fn = functools.partial( + _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + seed=sigmoid_noise_seed) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.scale = scale + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + self.attention_g = None + self.attention_score_bias = None + super(LuongMonotonicAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + + def build(self, input_shape): + super(LuongMonotonicAttentionV2, self).build(input_shape) + if self.scale and self.attention_g is None: + self.attention_g = self.add_weight( + "attention_g", initializer=init_ops.ones_initializer, shape=()) + if self.attention_score_bias is None: + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) + self.built = True + + def _calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). 
+ next_state: Same as alignments + """ + score = _luong_score(query, self.keys, self.attention_g) + score += self.attention_score_bias + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "scale": self.scale, + "sigmoid_noise": self.sigmoid_noise, + "sigmoid_noise_seed": self.sigmoid_noise_seed, + "score_bias_init": self.score_bias_init, + "mode": self.mode, + } + base_config = super(LuongMonotonicAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + class AttentionWrapperState( collections.namedtuple("AttentionWrapperState", ("cell_state", "attention", "time", "alignments", @@ -1017,7 +1920,15 @@ class AttentionWrapperState( def with_same_shape(old, new): """Check and set new tensor's shape.""" if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor): - return tensor_util.with_same_shape(old, new) + if not context.executing_eagerly(): + return tensor_util.with_same_shape(old, new) + else: + if old.shape.as_list() != new.shape.as_list(): + raise ValueError("The shape of the AttentionWrapperState is " + "expected to be same as the one to clone. " + "self.shape: %s, input.shape: %s" % + (old.shape, new.shape)) + return new return new return nest.map_structure( @@ -1026,6 +1937,82 @@ class AttentionWrapperState( super(AttentionWrapperState, self)._replace(**kwargs)) +def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, + check_inner_dims_defined=True): + """Convert to tensor and possibly mask `memory`. + + Args: + memory: `Tensor`, shaped `[batch_size, max_time, ...]`. + memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. 
+ memory_mask: `boolean` tensor with shape [batch_size, max_time]. The memory + should be skipped when the corresponding mask is False. + check_inner_dims_defined: Python boolean. If `True`, the `memory` + argument's shape is checked to ensure all but the two outermost + dimensions are fully defined. + + Returns: + A (possibly masked), checked, new `memory`. + + Raises: + ValueError: If `check_inner_dims_defined` is `True` and not + `memory.shape[2:].is_fully_defined()`. + """ + memory = nest.map_structure( + lambda m: ops.convert_to_tensor(m, name="memory"), memory) + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask can't be provided " + "at same time.") + if memory_sequence_length is not None: + memory_sequence_length = ops.convert_to_tensor( + memory_sequence_length, name="memory_sequence_length") + if check_inner_dims_defined: + def _check_dims(m): + if not m.get_shape()[2:].is_fully_defined(): + raise ValueError("Expected memory %s to have fully defined inner dims, " + "but saw shape: %s" % (m.name, m.get_shape())) + nest.map_structure(_check_dims, memory) + if memory_sequence_length is None and memory_mask is None: + return memory + elif memory_sequence_length is not None: + seq_len_mask = array_ops.sequence_mask( + memory_sequence_length, + maxlen=array_ops.shape(nest.flatten(memory)[0])[1], + dtype=nest.flatten(memory)[0].dtype) + else: + # For memory_mask is not None + seq_len_mask = math_ops.cast( + memory_mask, dtype=nest.flatten(memory)[0].dtype) + def _maybe_mask(m, seq_len_mask): + """Mask the memory based on the memory mask.""" + rank = m.get_shape().ndims + rank = rank if rank is not None else array_ops.rank(m) + extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) + seq_len_mask = array_ops.reshape( + seq_len_mask, + array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) + return m * seq_len_mask + + return nest.map_structure(lambda m: _maybe_mask(m, 
seq_len_mask), memory) + + +def _maybe_mask_score(score, memory_sequence_length=None, memory_mask=None, + score_mask_value=None): + """Mask the attention score based on the masks.""" + if memory_sequence_length is None and memory_mask is None: + return score + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask can't be provided " + "at same time.") + if memory_sequence_length is not None: + message = "All values in memory_sequence_length must greater than zero." + with ops.control_dependencies( + [check_ops.assert_positive(memory_sequence_length, message=message)]): + memory_mask = array_ops.sequence_mask( + memory_sequence_length, maxlen=array_ops.shape(score)[1]) + score_mask_values = score_mask_value * array_ops.ones_like(score) + return array_ops.where(memory_mask, score, score_mask_values) + + def hardmax(logits, name=None): """Returns batched one-hot vectors. @@ -1050,8 +2037,14 @@ def hardmax(logits, name=None): def _compute_attention(attention_mechanism, cell_output, attention_state, attention_layer): """Computes the attention and alignments for a given attention_mechanism.""" - alignments, next_attention_state = attention_mechanism( - cell_output, state=attention_state) + if isinstance(attention_mechanism, _BaseAttentionMechanismV2): + alignments, next_attention_state = attention_mechanism( + [cell_output, attention_state]) + else: + # For other class, assume they are following _BaseAttentionMechanism, which + # takes query and state as separate parameter. + alignments, next_attention_state = attention_mechanism( + cell_output, state=attention_state) # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] expanded_alignments = array_ops.expand_dims(alignments, 1) @@ -1064,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state, # the batched matmul is over memory_time, so the output shape is # [batch_size, 1, memory_size]. 
# we then squeeze out the singleton dim. - context = math_ops.matmul(expanded_alignments, attention_mechanism.values) - context = array_ops.squeeze(context, [1]) + context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values) + context_ = array_ops.squeeze(context_, [1]) if attention_layer is not None: - attention = attention_layer(array_ops.concat([cell_output, context], 1)) + attention = attention_layer(array_ops.concat([cell_output, context_], 1)) else: - attention = context + attention = context_ return attention, alignments, next_attention_state @@ -1088,7 +2081,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): output_attention=True, initial_cell_state=None, name=None, - attention_layer=None): + attention_layer=None, + attention_fn=None): """Construct the `AttentionWrapper`. **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in @@ -1132,7 +2126,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. If - attention_layer is set, this must be None. + attention_layer is set, this must be None. If attention_fn is set, + it must guaranteed that the outputs of attention_fn also meet the + above requirements. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). @@ -1158,6 +2154,12 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layers_size is set, this must be None. 
+ attention_fn: An optional callable function that allows users to provide + their own customized attention function, which takes input + (attention_mechanism, cell_output, attention_state, attention_layer) and + outputs (attention, alignments, next_attention_state). If provided, + the attention_layer_size should be the size of the outputs of + attention_fn. Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` @@ -1240,6 +2242,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): tensor_shape.dimension_value(attention_mechanism.values.shape[-1]) for attention_mechanism in attention_mechanisms) + if attention_fn is None: + attention_fn = _compute_attention + self._attention_fn = attention_fn + self._cell = cell self._attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn @@ -1443,7 +2449,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): all_attention_states = [] maybe_all_histories = [] for i, attention_mechanism in enumerate(self._attention_mechanisms): - attention, alignments, next_attention_state = _compute_attention( + attention, alignments, next_attention_state = self._attention_fn( attention_mechanism, cell_output, previous_attention_state[i], self._attention_layers[i] if self._attention_layers else None) alignment_history = previous_alignment_history[i].write( diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index 7eb95e5a70de985dca0d4b565ba03bdf454b6161..16dfa7ed8268d761dee49ec0146efabcaaef1393 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -23,8 +23,10 @@ import collections from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py from tensorflow.python.framework import ops from tensorflow.python.framework 
import tensor_shape +from tensorflow.python.keras import layers from tensorflow.python.layers import base as layers_base from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.util import nest @@ -146,3 +148,102 @@ class BasicDecoder(decoder.Decoder): sample_ids=sample_ids) outputs = BasicDecoderOutput(cell_outputs, sample_ids) return (outputs, next_state, next_inputs, finished) + + +class BasicDecoderV2(decoder.BaseDecoder): + """Basic sampling decoder.""" + + def __init__(self, cell, sampler, output_layer=None, **kwargs): + """Initialize BasicDecoder. + + Args: + cell: An `RNNCell` instance. + sampler: A `Sampler` instance. + output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., + `tf.layers.Dense`. Optional layer to apply to the RNN output prior to + storing the result or sampling. + **kwargs: Other keyward arguments for layer creation. + + Raises: + TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. + """ + rnn_cell_impl.assert_like_rnncell("cell", cell) + if not isinstance(sampler, sampler_py.Sampler): + raise TypeError("sampler must be a Sampler, received: %s" % (sampler,)) + if (output_layer is not None and + not isinstance(output_layer, layers.Layer)): + raise TypeError( + "output_layer must be a Layer, received: %s" % (output_layer,)) + self.cell = cell + self.sampler = sampler + self.output_layer = output_layer + super(BasicDecoderV2, self).__init__(**kwargs) + + def initialize(self, inputs, initial_state=None, **kwargs): + """Initialize the decoder.""" + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. 
+ self._cell_dtype = nest.flatten(initial_state)[0].dtype + return self.sampler.initialize(inputs, **kwargs) + (initial_state,) + + @property + def batch_size(self): + return self.sampler.batch_size + + def _rnn_output_size(self): + size = tensor_shape.TensorShape(self.cell.output_size) + if self.output_layer is None: + return size + else: + # To use layer's compute_output_shape, we need to convert the + # RNNCell's output_size entries into shapes with an unknown + # batch size. We then pass this through the layer's + # compute_output_shape and read off all but the first (batch) + # dimensions to get the output size of the rnn with the layer + # applied to the top. + output_shape_with_unknown_batch = nest.map_structure( + lambda s: tensor_shape.TensorShape([None]).concatenate(s), size) + layer_output_shape = self.output_layer.compute_output_shape( + output_shape_with_unknown_batch) + return nest.map_structure(lambda s: s[1:], layer_output_shape) + + @property + def output_size(self): + # Return the cell output and the id + return BasicDecoderOutput( + rnn_output=self._rnn_output_size(), + sample_id=self.sampler.sample_ids_shape) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and the sample_ids_dtype from the helper. + dtype = self._cell_dtype + return BasicDecoderOutput( + nest.map_structure(lambda _: dtype, self._rnn_output_size()), + self.sampler.sample_ids_dtype) + + def step(self, time, inputs, state): + """Perform a decoding step. + + Args: + time: scalar `int32` tensor. + inputs: A (structure of) input tensors. + state: A (structure of) state tensors and TensorArrays. + + Returns: + `(outputs, next_state, next_inputs, finished)`. 
+ """ + cell_outputs, cell_state = self.cell(inputs, state) + if self.output_layer is not None: + cell_outputs = self.output_layer(cell_outputs) + sample_ids = self.sampler.sample( + time=time, outputs=cell_outputs, state=cell_state) + (finished, next_inputs, next_state) = self.sampler.next_inputs( + time=time, + outputs=cell_outputs, + state=cell_state, + sample_ids=sample_ids) + outputs = BasicDecoderOutput(cell_outputs, sample_ids) + return (outputs, next_state, next_inputs, finished) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index ab36848f13ab3078cd232c18f140188e12db703b..1d773a449890cd7335b2225db39d79ca958a3276 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -24,11 +24,12 @@ import numpy as np from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.layers import base as layers_base +from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops @@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length): return ordered -def _check_maybe(t): +def _check_ndims(t): if t.shape.ndims is None: raise ValueError( "Expected tensor (%s) to have known rank, but ndims == None." 
% t) + def _check_static_batch_beam_maybe(shape, batch_size, beam_width): """Raises an exception if dimensions are known statically and can not be reshaped to [batch_size, beam_size, -1]. @@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width): return False return True + def _check_batch_beam(t, batch_size, beam_width): """Returns an Assert operation checking that the elements of the stacked TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point, @@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width): return control_flow_ops.Assert(condition, [error_message]) +class BeamSearchDecoderMixin(object): + """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder. -class BeamSearchDecoder(decoder.Decoder): - """BeamSearch sampling decoder. - - **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in - `AttentionWrapper`, then you must ensure that: - - - The encoder output has been tiled to `beam_width` via - `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). - - The `batch_size` argument passed to the `zero_state` method of this - wrapper is equal to `true_batch_size * beam_width`. - - The initial state created with `zero_state` above contains a - `cell_state` value containing properly tiled final state from the - encoder. - - An example: - - ``` - tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( - encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( - encoder_final_state, multiplier=beam_width) - tiled_sequence_length = tf.contrib.seq2seq.tile_batch( - sequence_length, multiplier=beam_width) - attention_mechanism = MyFavoriteAttentionMechanism( - num_units=attention_depth, - memory=tiled_inputs, - memory_sequence_length=tiled_sequence_length) - attention_cell = AttentionWrapper(cell, attention_mechanism, ...) 
- decoder_initial_state = attention_cell.zero_state( - dtype, batch_size=true_batch_size * beam_width) - decoder_initial_state = decoder_initial_state.clone( - cell_state=tiled_encoder_final_state) - ``` - - Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use - when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages - the translation to cover all inputs. + It is expected to be used a base class for concrete BeamSearchDecoder. Since + this is a mixin class, it is expected to be used together with other class as + base. """ def __init__(self, cell, - embedding, - start_tokens, - end_token, - initial_state, beam_width, output_layer=None, length_penalty_weight=0.0, coverage_penalty_weight=0.0, - reorder_tensor_arrays=True): - """Initialize the BeamSearchDecoder. + reorder_tensor_arrays=True, + **kwargs): + """Initialize the BeamSearchDecoderMixin. Args: cell: An `RNNCell` instance. - embedding: A callable that takes a vector tensor of `ids` (argmax ids), - or the `params` argument for `embedding_lookup`. - start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. - end_token: `int32` scalar, the token that marks end of decoding. - initial_state: A (possibly nested tuple of...) tensors and TensorArrays. beam_width: Python integer, the number of beams. - output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output prior - to storing the result or sampling. + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. coverage_penalty_weight: Float weight to penalize the coverage of source sentence. Disabled with 0.0. @@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder): Otherwise, the `TensorArray` will be returned as is. 
Set this flag to `False` if the cell state contains `TensorArray`s that are not amenable to reordering. + **kwargs: Dict, other keyword arguments for parent class. Raises: TypeError: if `cell` is not an instance of `RNNCell`, - or `output_layer` is not an instance of `tf.layers.Layer`. - ValueError: If `start_tokens` is not a vector or - `end_token` is not a scalar. + or `output_layer` is not an instance of `tf.keras.layers.Layer`. """ rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access if (output_layer is not None and - not isinstance(output_layer, layers_base.Layer)): + not isinstance(output_layer, layers.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer self._reorder_tensor_arrays = reorder_tensor_arrays - if callable(embedding): - self._embedding_fn = embedding - else: - self._embedding_fn = ( - lambda ids: embedding_ops.embedding_lookup(embedding, ids)) - - self._start_tokens = ops.convert_to_tensor( - start_tokens, dtype=dtypes.int32, name="start_tokens") - if self._start_tokens.get_shape().ndims != 1: - raise ValueError("start_tokens must be a vector") - self._end_token = ops.convert_to_tensor( - end_token, dtype=dtypes.int32, name="end_token") - if self._end_token.get_shape().ndims != 0: - raise ValueError("end_token must be a scalar") - - self._batch_size = array_ops.size(start_tokens) + self._start_tokens = None + self._end_token = None + self._batch_size = None self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._coverage_penalty_weight = coverage_penalty_weight - self._initial_cell_state = nest.map_structure( - self._maybe_split_batch_beams, initial_state, self._cell.state_size) - self._start_tokens = array_ops.tile( - array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) - self._start_inputs = self._embedding_fn(self._start_tokens) - - self._finished = array_ops.one_hot( - 
array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=False, - off_value=True, - dtype=dtypes.bool) + super(BeamSearchDecoderMixin, self).__init__(**kwargs) @property def batch_size(self): return self._batch_size def _rnn_output_size(self): + """Get the output shape from the RNN layer.""" size = self._cell.output_size if self._output_layer is None: return size @@ -393,50 +332,6 @@ class BeamSearchDecoder(decoder.Decoder): predicted_ids=tensor_shape.TensorShape([self._beam_width]), parent_ids=tensor_shape.TensorShape([self._beam_width])) - @property - def output_dtype(self): - # Assume the dtype of the cell is the output_size structure - # containing the input_state's first component's dtype. - # Return that structure and int32 (the id) - dtype = nest.flatten(self._initial_cell_state)[0].dtype - return BeamSearchDecoderOutput( - scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), - predicted_ids=dtypes.int32, - parent_ids=dtypes.int32) - - def initialize(self, name=None): - """Initialize the decoder. - - Args: - name: Name scope for any created operations. - - Returns: - `(finished, start_inputs, initial_state)`. 
- """ - finished, start_inputs = self._finished, self._start_inputs - - dtype = nest.flatten(self._initial_cell_state)[0].dtype - log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) - array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=ops.convert_to_tensor(0.0, dtype=dtype), - off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), - dtype=dtype) - init_attention_probs = get_attention_probs( - self._initial_cell_state, self._coverage_penalty_weight) - if init_attention_probs is None: - init_attention_probs = () - - initial_state = BeamSearchDecoderState( - cell_state=self._initial_cell_state, - log_probs=log_probs, - finished=finished, - lengths=array_ops.zeros( - [self._batch_size, self._beam_width], dtype=dtypes.int64), - accumulated_attention_probs=init_attention_probs) - - return (finished, start_inputs, initial_state) - def finalize(self, outputs, final_state, sequence_lengths): """Finalize and return the predicted_ids. @@ -562,7 +457,7 @@ class BeamSearchDecoder(decoder.Decoder): """ if isinstance(t, tensor_array_ops.TensorArray): return t - _check_maybe(t) + _check_ndims(t) if t.shape.ndims >= 1: return self._split_batch_beams(t, s) else: @@ -586,7 +481,7 @@ class BeamSearchDecoder(decoder.Decoder): """ if isinstance(t, tensor_array_ops.TensorArray): return t - _check_maybe(t) + _check_ndims(t) if t.shape.ndims >= 2: return self._merge_batch_beams(t, s) else: @@ -609,11 +504,18 @@ class BeamSearchDecoder(decoder.Decoder): if not isinstance(t, tensor_array_ops.TensorArray): return t # pylint: disable=protected-access - if (not t._infer_shape or not t._element_shape - or t._element_shape[0].ndims is None - or t._element_shape[0].ndims < 1): + # This is a bad hack due to the implementation detail of eager/graph TA. + # TODO(b/124374427): Update this to use public property of TensorArray. 
+ if context.executing_eagerly(): + element_shape = t._element_shape + else: + element_shape = t._element_shape[0] + if (not t._infer_shape + or not t._element_shape + or element_shape.ndims is None + or element_shape.ndims < 1): shape = ( - t._element_shape[0] if t._infer_shape and t._element_shape + element_shape if t._infer_shape and t._element_shape else tensor_shape.TensorShape(None)) tf_logging.warn("The TensorArray %s in the cell state is not amenable to " "sorting based on the beam search result. For a " @@ -621,10 +523,10 @@ class BeamSearchDecoder(decoder.Decoder): "defined and have at least a rank of 1, but saw shape: %s" % (t.handle.name, shape)) return t - shape = t._element_shape[0] # pylint: enable=protected-access if not _check_static_batch_beam_maybe( - shape, tensor_util.constant_value(self._batch_size), self._beam_width): + element_shape, tensor_util.constant_value(self._batch_size), + self._beam_width): return t t = t.stack() with ops.control_dependencies( @@ -684,6 +586,359 @@ class BeamSearchDecoder(decoder.Decoder): return (beam_search_output, beam_search_state, next_inputs, finished) +class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. + + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. 
+ + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages + the decoder to cover all inputs. + """ + + def __init__(self, + cell, + embedding, + start_tokens, + end_token, + initial_state, + beam_width, + output_layer=None, + length_penalty_weight=0.0, + coverage_penalty_weight=0.0, + reorder_tensor_arrays=True): + """Initialize the BeamSearchDecoder. + + Args: + cell: An `RNNCell` instance. + embedding: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + beam_width: Python integer, the number of beams. + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. + length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. 
Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.keras.layers.Layer`. + ValueError: If `start_tokens` is not a vector or + `end_token` is not a scalar. + """ + super(BeamSearchDecoder, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays) + + if callable(embedding): + self._embedding_fn = embedding + else: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + def initialize(self, name=None): + """Initialize the decoder. 
+ + Args: + name: Name scope for any created operations. + + Returns: + `(finished, start_inputs, initial_state)`. + """ + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + +class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. 
+ + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages + the decoding to cover all inputs. + """ + + def __init__(self, + cell, + beam_width, + embedding_fn=None, + output_layer=None, + length_penalty_weight=0.0, + coverage_penalty_weight=0.0, + reorder_tensor_arrays=True, + **kwargs): + """Initialize the BeamSearchDecoderV2. + + Args: + cell: An `RNNCell` instance. + beam_width: Python integer, the number of beams. + embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids). + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. 
Optional layer to apply to the RNN output + prior to storing the result or sampling. + length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. + **kwargs: Dict, other keyword arguments for initialization. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.keras.layers.Layer`. + """ + super(BeamSearchDecoderV2, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays, + **kwargs) + + if embedding_fn is None or callable(embedding_fn): + self._embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be a callable, got %s" % + type(embedding_fn)) + + def initialize(self, + embedding, + start_tokens, + end_token, + initial_state): + """Initialize the decoder. + + Args: + embedding: A tensor from the embedding layer output, which is the + `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + Returns: + `(finished, start_inputs, initial_state)`. + Raises: + ValueError: If `start_tokens` is not a vector or `end_token` is not a + scalar. 
+ """ + if embedding is not None and self._embedding_fn is not None: + raise ValueError( + "embedding and embedding_fn cannot be provided at same time") + elif embedding is not None: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + 
accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + def call(self, embeddning, start_tokens, end_token, initial_state, **kwargs): + init_kwargs = kwargs + init_kwargs["start_tokens"] = start_tokens + init_kwargs["end_token"] = end_token + init_kwargs["initial_state"] = initial_state + return decoder.dynamic_decode(self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + decoder_init_input=embeddning, + decoder_init_kwargs=init_kwargs) + + def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, beam_width, end_token, length_penalty_weight, coverage_penalty_weight): @@ -921,6 +1176,7 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight, """ length_penalty_ = _length_penalty( sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) + length_penalty_ = math_ops.cast(length_penalty_, dtype=log_probs.dtype) scores = log_probs / length_penalty_ coverage_penalty_weight = ops.convert_to_tensor( @@ -1067,7 +1323,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, """ if isinstance(gather_from, tensor_array_ops.TensorArray): return gather_from - _check_maybe(gather_from) + _check_ndims(gather_from) if gather_from.shape.ndims >= len(gather_shape): return _tensor_gather_helper( gather_indices=gather_indices, diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py 
b/tensorflow/contrib/seq2seq/python/ops/decoder.py index f58268eff525a4b592c79acb32207e1a3f62bdc7..33f7bac8159401175ce57c0463fff1398c1dd9bb 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -135,6 +136,127 @@ class Decoder(object): return False +class BaseDecoder(layers.Layer): + """An RNN Decoder that is based on a Keras layer. + + Concepts used by this interface: + - `inputs`: (structure of) tensors and TensorArrays that is passed as input to + the RNNCell composing the decoder, at each time step. + - `state`: (structure of) tensors and TensorArrays that is passed to the + RNNCell instance as the state. + - `memory`: (sturecute of) tensors that is usually the full output of the + encoder, which will be used for the attention wrapper for the RNNCell. + - `finished`: boolean tensor telling whether each sequence in the batch is + finished. + - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each + time step. 
+ """ + + def __init__(self, + output_time_major=False, + impute_finished=False, + maximum_iterations=None, + parallel_iterations=32, + swap_memory=False, + **kwargs): + self.output_time_major = output_time_major + self.impute_finished = impute_finished + self.maximum_iterations = maximum_iterations + self.parallel_iterations = parallel_iterations + self.swap_memory = swap_memory + super(BaseDecoder, self).__init__(**kwargs) + + def call(self, inputs, initial_state=None, **kwargs): + init_kwargs = kwargs + init_kwargs["initial_state"] = initial_state + return dynamic_decode(self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + decoder_init_input=inputs, + decoder_init_kwargs=init_kwargs) + + @property + def batch_size(self): + """The batch size of input values.""" + raise NotImplementedError + + @property + def output_size(self): + """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s].""" + raise NotImplementedError + + @property + def output_dtype(self): + """A (possibly nested tuple of...) dtype[s].""" + raise NotImplementedError + + def initialize(self, inputs, initial_state=None, **kwargs): + """Called before any decoding iterations. + + This methods must compute initial input values and initial state. + + Args: + inputs: (structure of) tensors that contains the input for the decoder. In + the normal case, its a tensor with shape [batch, timestep, embedding]. + initial_state: (structure of) tensors that contains the initial state for + the RNNCell. + **kwargs: Other arguments that are passed in from layer.call() method. It + could contains item like input sequence_length, or masking for input. + + Returns: + `(finished, initial_inputs, initial_state)`: initial values of + 'finished' flags, inputs and state. 
+ """ + raise NotImplementedError + + def step(self, time, inputs, state): + """Called per step of decoding (but only once for dynamic decoding). + + Args: + time: Scalar `int32` tensor. Current step number. + inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time + step. + state: RNNCell state (possibly nested tuple of) tensor[s] from previous + time step. + + Returns: + `(outputs, next_state, next_inputs, finished)`: `outputs` is an object + containing the decoder output, `next_state` is a (structure of) state + tensors and TensorArrays, `next_inputs` is the tensor that should be used + as input for the next step, `finished` is a boolean tensor telling whether + the sequence is complete, for each sequence in the batch. + """ + raise NotImplementedError + + def finalize(self, outputs, final_state, sequence_lengths): + raise NotImplementedError + + @property + def tracks_own_finished(self): + """Describes whether the Decoder keeps track of finished states. + + Most decoders will emit a true/false `finished` value independently + at each time step. In this case, the `dynamic_decode` function keeps track + of which batch entries are already finished, and performs a logical OR to + insert new batches to the finished set. + + Some decoders, however, shuffle batches / beams between time steps and + `dynamic_decode` will mix up the finished state across these entries because + it does not track the reshuffle across time steps. In this case, it is + up to the decoder to declare that it will keep track of its own finished + state by setting this property to `True`. + + Returns: + Python bool. + """ + return False + + # TODO(scottzhu): Add build/get_config/from_config and other layer methods. 
+ + def _create_zero_outputs(size, dtype, batch_size): """Create a zero outputs Tensor structure.""" def _create(s, d): @@ -149,7 +271,8 @@ def dynamic_decode(decoder, maximum_iterations=None, parallel_iterations=32, swap_memory=False, - scope=None): + scope=None, + **kwargs): """Perform dynamic decoding with `decoder`. Calls initialize() once and step() repeatedly on the Decoder object. @@ -171,6 +294,9 @@ def dynamic_decode(decoder, parallel_iterations: Argument passed to `tf.while_loop`. swap_memory: Argument passed to `tf.while_loop`. scope: Optional variable scope to use. + **kwargs: dict, other keyword arguments for dynamic_decode. It might contain + arguments for `BaseDecoder` to initialize, which takes all tensor inputs + during call(). Returns: `(final_outputs, final_state, final_sequence_lengths)`. @@ -179,7 +305,7 @@ def dynamic_decode(decoder, TypeError: if `decoder` is not an instance of `Decoder`. ValueError: if `maximum_iterations` is provided but is not a scalar. """ - if not isinstance(decoder, Decoder): + if not isinstance(decoder, (Decoder, BaseDecoder)): raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) @@ -204,7 +330,14 @@ def dynamic_decode(decoder, if maximum_iterations.get_shape().ndims != 0: raise ValueError("maximum_iterations must be a scalar") - initial_finished, initial_inputs, initial_state = decoder.initialize() + if isinstance(decoder, Decoder): + initial_finished, initial_inputs, initial_state = decoder.initialize() + else: + # For BaseDecoder that takes tensor inputs during call. 
+ decoder_init_input = kwargs.pop("decoder_init_input", None) + decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {}) + initial_finished, initial_inputs, initial_state = decoder.initialize( + decoder_init_input, **decoder_init_kwargs) zero_outputs = _create_zero_outputs(decoder.output_size, decoder.output_dtype, @@ -222,7 +355,7 @@ def dynamic_decode(decoder, def _shape(batch_size, from_shape): if (not isinstance(from_shape, tensor_shape.TensorShape) or from_shape.ndims == 0): - return tensor_shape.TensorShape(None) + return None else: batch_size = tensor_util.constant_value( ops.convert_to_tensor( diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index 3245cc5e72154289ea3ba000b9a30586a7ad03a9..033c2eb0801d5a51ee937f5e960faa91a6f1ae54 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -32,9 +32,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical from tensorflow.python.util import nest __all__ = [ @@ -51,6 +50,68 @@ __all__ = [ _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access +# The following sample functions (_call_sampler, bernoulli_sample, +# categorical_sample) mimic TensorFlow Probability distribution semantics. + + +def _call_sampler(sample_n_fn, sample_shape, name=None): + """Reshapes vector of samples.""" + with ops.name_scope(name, "call_sampler", values=[sample_shape]): + sample_shape = ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32, name="sample_shape") + # Ensure sample_shape is a vector (vs just a scalar). 
+ pad = math_ops.cast(math_ops.equal(array_ops.rank(sample_shape), 0), + dtypes.int32) + sample_shape = array_ops.reshape( + sample_shape, + array_ops.pad(array_ops.shape(sample_shape), + paddings=[[pad, 0]], + constant_values=1)) + samples = sample_n_fn(math_ops.reduce_prod(sample_shape)) + batch_event_shape = array_ops.shape(samples)[1:] + final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) + return array_ops.reshape(samples, final_shape) + + +def bernoulli_sample(probs=None, logits=None, dtype=dtypes.int32, + sample_shape=(), seed=None): + """Samples from Bernoulli distribution.""" + if probs is None: + probs = math_ops.sigmoid(logits, name="probs") + else: + probs = ops.convert_to_tensor(probs, name="probs") + batch_shape_tensor = array_ops.shape(probs) + def _sample_n(n): + """Sample vector of Bernoullis.""" + new_shape = array_ops.concat([[n], batch_shape_tensor], 0) + uniform = random_ops.random_uniform( + new_shape, seed=seed, dtype=probs.dtype) + return math_ops.cast(math_ops.less(uniform, probs), dtype) + return _call_sampler(_sample_n, sample_shape) + + +def categorical_sample(logits, dtype=dtypes.int32, + sample_shape=(), seed=None): + """Samples from categorical distribution.""" + logits = ops.convert_to_tensor(logits, name="logits") + event_size = array_ops.shape(logits)[-1] + batch_shape_tensor = array_ops.shape(logits)[:-1] + def _sample_n(n): + """Sample vector of categoricals.""" + if logits.shape.ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, event_size]) + sample_dtype = dtypes.int64 if logits.dtype.size > 4 else dtypes.int32 + draws = random_ops.multinomial( + logits_2d, n, seed=seed, output_dtype=sample_dtype) + draws = array_ops.reshape( + array_ops.transpose(draws), + array_ops.concat([[n], batch_shape_tensor], 0)) + return math_ops.cast(draws, dtype) + return _call_sampler(_sample_n, sample_shape) + + def _unstack_ta(inp): return tensor_array_ops.TensorArray( dtype=inp.dtype, 
size=array_ops.shape(inp)[0], @@ -307,14 +368,14 @@ class ScheduledEmbeddingTrainingHelper(TrainingHelper): with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", [time, outputs, state]): # Return -1s where we did not sample, and sample_ids elsewhere - select_sampler = bernoulli.Bernoulli( - probs=self._sampling_probability, dtype=dtypes.bool) - select_sample = select_sampler.sample( - sample_shape=self.batch_size, seed=self._scheduling_seed) - sample_id_sampler = categorical.Categorical(logits=outputs) + select_sample = bernoulli_sample( + probs=self._sampling_probability, + dtype=dtypes.bool, + sample_shape=self.batch_size, + seed=self._scheduling_seed) return array_ops.where( select_sample, - sample_id_sampler.sample(seed=self._seed), + categorical_sample(logits=outputs, seed=self._seed), gen_array_ops.fill([self.batch_size], -1)) def next_inputs(self, time, outputs, state, sample_ids, name=None): @@ -425,8 +486,10 @@ class ScheduledOutputTrainingHelper(TrainingHelper): def sample(self, time, outputs, state, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperSample", [time, outputs, state]): - sampler = bernoulli.Bernoulli(probs=self._sampling_probability) - return sampler.sample(sample_shape=self.batch_size, seed=self._seed) + return bernoulli_sample( + probs=self._sampling_probability, + sample_shape=self.batch_size, + seed=self._seed) def next_inputs(self, time, outputs, state, sample_ids, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs", @@ -610,8 +673,7 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper): else: logits = outputs / self._softmax_temperature - sample_id_sampler = categorical.Categorical(logits=logits) - sample_ids = sample_id_sampler.sample(seed=self._seed) + sample_ids = categorical_sample(logits=logits, seed=self._seed) return sample_ids diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py index 
39a6d2f58b140706a94d83273d3327edd1891368..0fbfd6187030f14ac105a18b3e09b7a42d4de32a 100644 --- a/tensorflow/contrib/seq2seq/python/ops/loss.py +++ b/tensorflow/contrib/seq2seq/python/ops/loss.py @@ -20,11 +20,12 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.keras.losses import Loss from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -__all__ = ["sequence_loss"] +__all__ = ["sequence_loss", "SequenceLoss"] def sequence_loss(logits, @@ -32,16 +33,26 @@ def sequence_loss(logits, weights, average_across_timesteps=True, average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False, softmax_loss_function=None, name=None): """Weighted cross-entropy loss for a sequence of logits. - Depending on the values of `average_across_timesteps` and - `average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these - arguments reduce the cross-entropy at each target, which has shape - `[batch_size, sequence_length]`, over their respective dimensions. For - example, if `average_across_timesteps` is `True` and `average_across_batch` - is `False`, then the return Tensor will have shape `[batch_size]`. + Depending on the values of `average_across_timesteps` / `sum_over_timesteps` + and `average_across_batch` / `sum_over_batch`, the return Tensor will have + rank 0, 1, or 2 as these arguments reduce the cross-entropy at each target, + which has shape `[batch_size, sequence_length]`, over their respective + dimensions. For example, if `average_across_timesteps` is `True` and + `average_across_batch` is `False`, then the return Tensor will have shape + `[batch_size]`. + + Note that `average_across_timesteps` and `sum_over_timesteps` cannot be True + at same time. Same for `average_across_batch` and `sum_over_batch`. 
+ + The recommended loss reduction in tf 2.0 has been changed to sum_over, instead + of weighted average. User are recommend to use `sum_over_timesteps` and + `sum_over_batch` for reduction. Args: logits: A Tensor of shape @@ -58,6 +69,12 @@ def sequence_loss(logits, dimension and divide the cost by the total label weight across timesteps. average_across_batch: If set, sum the cost across the batch dimension and divide the returned cost by the batch size. + sum_over_timesteps: If set, sum the cost across the sequence dimension and + divide the size of the sequence. Note that any element with 0 weights will + be excluded from size calculation. + sum_over_batch: if set, sum the cost across the batch dimension and divide + the total cost by the batch size. Not that any element with 0 weights will + be excluded from size calculation. softmax_loss_function: Function (labels, logits) -> loss-batch to be used instead of the standard softmax (the default if this is None). **Note that to avoid confusion, it is required for the function to accept @@ -78,11 +95,15 @@ def sequence_loss(logits, raise ValueError("Logits must be a " "[batch_size x sequence_length x logits] tensor") if len(targets.get_shape()) != 2: - raise ValueError("Targets must be a [batch_size x sequence_length] " - "tensor") + raise ValueError("Targets must be a [batch_size x sequence_length] tensor") if len(weights.get_shape()) != 2: - raise ValueError("Weights must be a [batch_size x sequence_length] " - "tensor") + raise ValueError("Weights must be a [batch_size x sequence_length] tensor") + if average_across_timesteps and sum_over_timesteps: + raise ValueError("average_across_timesteps and sum_over_timesteps cannot " + "be set to True at same time.") + if average_across_batch and sum_over_batch: + raise ValueError("average_across_batch and sum_over_batch cannot be set " + "to True at same time.") with ops.name_scope(name, "sequence_loss", [logits, targets, weights]): num_classes = 
array_ops.shape(logits)[2] logits_flat = array_ops.reshape(logits, [-1, num_classes]) @@ -96,20 +117,56 @@ def sequence_loss(logits, if average_across_timesteps and average_across_batch: crossent = math_ops.reduce_sum(crossent) total_size = math_ops.reduce_sum(weights) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size + crossent = math_ops.div_no_nan(crossent, total_size) + elif sum_over_timesteps and sum_over_batch: + crossent = math_ops.reduce_sum(crossent) + total_count = math_ops.cast(math_ops.count_nonzero(weights), + crossent.dtype) + crossent = math_ops.div_no_nan(crossent, total_count) else: - batch_size = array_ops.shape(logits)[0] - sequence_length = array_ops.shape(logits)[1] - crossent = array_ops.reshape(crossent, [batch_size, sequence_length]) - if average_across_timesteps and not average_across_batch: - crossent = math_ops.reduce_sum(crossent, axis=[1]) - total_size = math_ops.reduce_sum(weights, axis=[1]) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size - if not average_across_timesteps and average_across_batch: - crossent = math_ops.reduce_sum(crossent, axis=[0]) - total_size = math_ops.reduce_sum(weights, axis=[0]) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size + crossent = array_ops.reshape(crossent, array_ops.shape(logits)[0:2]) + if average_across_timesteps or average_across_batch: + reduce_axis = [0] if average_across_batch else [1] + crossent = math_ops.reduce_sum(crossent, axis=reduce_axis) + total_size = math_ops.reduce_sum(weights, axis=reduce_axis) + crossent = math_ops.div_no_nan(crossent, total_size) + elif sum_over_timesteps or sum_over_batch: + reduce_axis = [0] if sum_over_batch else [1] + crossent = math_ops.reduce_sum(crossent, axis=reduce_axis) + total_count = math_ops.cast( + math_ops.count_nonzero(weights, axis=reduce_axis), + dtype=crossent.dtype) + crossent = math_ops.div_no_nan(crossent, total_count) 
return crossent + + +class SequenceLoss(Loss): + """Weighted cross-entropy loss for a sequence of logits.""" + + def __init__(self, + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True, + softmax_loss_function=None, + name=None): + super(SequenceLoss, self).__init__(name=name) + self.average_across_timesteps = average_across_timesteps + self.average_across_batch = average_across_batch + self.sum_over_timesteps = sum_over_timesteps + self.sum_over_batch = sum_over_batch + self.softmax_loss_function = softmax_loss_function + + def __call__(self, y_true, y_pred, sample_weight=None): + """Override the parent __call__ to have a customized reduce behavior.""" + return sequence_loss(y_pred, y_true, sample_weight, + average_across_timesteps=self.average_across_timesteps, + average_across_batch=self.average_across_batch, + sum_over_timesteps=self.sum_over_timesteps, + sum_over_batch=self.sum_over_batch, + softmax_loss_function=self.softmax_loss_function, + name=self.name) + + def call(self, y_true, y_pred): + # Skip this method since the __call__ contains real implementation. + pass diff --git a/tensorflow/contrib/seq2seq/python/ops/sampler.py b/tensorflow/contrib/seq2seq/python/ops/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3e48b3bc61c0ff94ae0a1794767c7ff6914969 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/sampler.py @@ -0,0 +1,765 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A library of sampler for use with SamplingDecoders.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.util import nest + +__all__ = [ + "Sampler", + "TrainingSampler", + "GreedyEmbeddingSampler", + "SampleEmbeddingSampler", + "CustomSampler", + "ScheduledEmbeddingTrainingSampler", + "ScheduledOutputTrainingSampler", + "InferenceSampler", +] + +_transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access + + +@six.add_metaclass(abc.ABCMeta) +class Sampler(object): + """Interface for implementing sampling in seq2seq decoders. + + Sampler instances are used by `BasicDecoder`. 
The normal usage of a sampler is + like below: + sampler = Sampler(init_args) + (initial_finished, initial_inputs) = sampler.initialize(input_tensors) + for time_step in range(time): + cell_output, cell_state = cell.call(cell_input, previous_state) + sample_ids = sampler.sample(time_step, cell_output, cell_state) + (finished, next_inputs, next_state) = sampler.next_inputs( + time_step, cell_output, cell_state, sample_ids) + + Note that tensor inputs should not be fed to the Sampler as __init__() + parameters; instead, they should be fed by decoders via initialize(). + """ + + @abc.abstractmethod + def initialize(self, inputs, **kwargs): + """Initialize the sampler with the input tensors. + + This method is supposed to be invoked only once, before calling any other + methods of the Sampler. + + Args: + inputs: A (structure of) input tensors, it could be a nested tuple or a + single tensor. + **kwargs: Other kwargs for initialization. It could contain tensors like + mask for inputs, or non tensor parameter. + + Returns: + `(initial_finished, initial_inputs)`. + """ + pass + + @abc.abstractmethod + def sample(self, time, outputs, state): + """Returns `sample_ids`.""" + pass + + @abc.abstractmethod + def next_inputs(self, time, outputs, state, sample_ids): + """Returns `(finished, next_inputs, next_state)`.""" + pass + + @abc.abstractproperty + def batch_size(self): + """Batch size of tensor returned by `sample`. + + Returns a scalar int32 tensor. The return value might not be available + before the invocation of initialize(); in this case, ValueError is raised. + """ + raise NotImplementedError("batch_size has not been implemented") + + @abc.abstractproperty + def sample_ids_shape(self): + """Shape of tensor returned by `sample`, excluding the batch dimension. + + Returns a `TensorShape`. The return value might not be available before the + invocation of initialize(). 
+ """ + raise NotImplementedError("sample_ids_shape has not been implemented") + + @abc.abstractproperty + def sample_ids_dtype(self): + """DType of tensor returned by `sample`. + + Returns a DType. The return value might not available before the + invocation of initialize(). + """ + raise NotImplementedError("sample_ids_dtype has not been implemented") + + +class CustomSampler(Sampler): + """Base abstract class that allows the user to customize sampling.""" + + def __init__(self, + initialize_fn, + sample_fn, + next_inputs_fn, + sample_ids_shape=None, + sample_ids_dtype=None): + """Initializer. + + Args: + initialize_fn: callable that returns `(finished, next_inputs)` for the + first iteration. + sample_fn: callable that takes `(time, outputs, state)` and emits tensor + `sample_ids`. + next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)` + and emits `(finished, next_inputs, next_state)`. + sample_ids_shape: Either a list of integers, or a 1-D Tensor of type + `int32`, the shape of each value in the `sample_ids` batch. Defaults to + a scalar. + sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to int32. 
+ """ + self._initialize_fn = initialize_fn + self._sample_fn = sample_fn + self._next_inputs_fn = next_inputs_fn + self._batch_size = None + self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or []) + self._sample_ids_dtype = sample_ids_dtype or dtypes.int32 + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return self._sample_ids_shape + + @property + def sample_ids_dtype(self): + return self._sample_ids_dtype + + def initialize(self, inputs, **kwargs): + (finished, next_inputs) = self._initialize_fn(inputs, **kwargs) + if self._batch_size is None: + self._batch_size = array_ops.size(finished) + return (finished, next_inputs) + + def sample(self, time, outputs, state): + return self._sample_fn(time=time, outputs=outputs, state=state) + + def next_inputs(self, time, outputs, state, sample_ids): + return self._next_inputs_fn( + time=time, outputs=outputs, state=state, sample_ids=sample_ids) + + +class TrainingSampler(Sampler): + """A Sampler for use during training. + + Only reads inputs. + + Returned sample_ids are the argmax of the RNN output logits. + """ + + def __init__(self, time_major=False): + """Initializer. + + Args: + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + + Raises: + ValueError: if `sequence_length` is not a 1D tensor. + """ + self.time_major = time_major + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return tensor_shape.TensorShape([]) + + @property + def sample_ids_dtype(self): + return dtypes.int32 + + def initialize(self, inputs, sequence_length=None): + """Initialize the TrainSampler. 
+ + Args: + inputs: A (structure of) input tensors. + sequence_length: An int32 vector tensor. + + Returns: + (finished, next_inputs), a tuple of two items. The first item is a boolean + vector to indicate whether the item in the batch has finished. The + second item is the first slice of input data based on the timestep + dimension (usually the second dim of the input). + """ + self.inputs = ops.convert_to_tensor(inputs, name="inputs") + if not self.time_major: + inputs = nest.map_structure(_transpose_batch_time, inputs) + + self.input_tas = nest.map_structure(_unstack_ta, inputs) + if sequence_length is None: + raise ValueError("sequence_length is required for TrainingSampler") + self.sequence_length = ops.convert_to_tensor( + sequence_length, name="sequence_length") + if self.sequence_length.get_shape().ndims != 1: + raise ValueError( + "Expected sequence_length to be a vector, but received shape: %s" % + self.sequence_length.get_shape()) + + self.zero_inputs = nest.map_structure( + lambda inp: array_ops.zeros_like(inp[0, :]), inputs) + + self._batch_size = array_ops.size(self.sequence_length) + + finished = math_ops.equal(0, self.sequence_length) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, + lambda: self.zero_inputs, + lambda: nest.map_structure(lambda inp: inp.read(0), self.input_tas)) + return (finished, next_inputs) + + def sample(self, time, outputs, state): + del state + sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32) + return sample_ids + + def next_inputs(self, time, outputs, state, sample_ids): + del sample_ids + next_time = time + 1 + finished = (next_time >= self.sequence_length) + all_finished = math_ops.reduce_all(finished) + + def read_from_ta(inp): + return inp.read(next_time) + + next_inputs = control_flow_ops.cond( + all_finished, + lambda: self.zero_inputs, + lambda: nest.map_structure(read_from_ta, self.input_tas)) + return (finished, next_inputs, state) + + 
+class ScheduledEmbeddingTrainingSampler(TrainingSampler): + """A training sampler that adds scheduled sampling. + + Returns -1s for sample_ids where no sampling took place; valid sample id + values elsewhere. + """ + + def __init__(self, + sampling_probability, + embedding_fn=None, + time_major=False, + seed=None, + scheduling_seed=None): + """Initializer. + + Args: + sampling_probability: A `float32` 0-D or 1-D tensor: the probability of + sampling categorically from the output ids instead of reading directly + from the inputs. + embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + seed: The sampling seed. + scheduling_seed: The schedule decision rule sampling seed. + + Raises: + ValueError: if `sampling_probability` is not a scalar or vector. + """ + if callable(embedding_fn) or embedding_fn is None: + self.embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be callable, got %s" + % type(embedding_fn)) + self.sampling_probability = ops.convert_to_tensor( + sampling_probability, name="sampling_probability") + if self.sampling_probability.get_shape().ndims not in (0, 1): + raise ValueError( + "sampling_probability must be either a scalar or a vector. 
" + "saw shape: %s" % (self.sampling_probability.get_shape())) + self.seed = seed + self.scheduling_seed = scheduling_seed + super(ScheduledEmbeddingTrainingSampler, + self).__init__(time_major=time_major) + + def initialize(self, inputs, sequence_length=None, embedding=None): + if self.embedding_fn is None: + if embedding is None: + raise ValueError("embedding is required as a keyword argument for " + "ScheduledEmbeddingTrainingSampler") + self.embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + return super(ScheduledEmbeddingTrainingSampler, self).initialize( + inputs, sequence_length=sequence_length) + + def sample(self, time, outputs, state): + del state + # Return -1s where we did not sample, and sample_ids elsewhere + select_sample = bernoulli_sample( + probs=self.sampling_probability, + dtype=dtypes.bool, + sample_shape=self.batch_size, + seed=self.scheduling_seed) + return array_ops.where(select_sample, + categorical_sample(logits=outputs, seed=self.seed), + gen_array_ops.fill([self.batch_size], -1)) + + def next_inputs(self, time, outputs, state, sample_ids): + (finished, base_next_inputs, state) = ( + super(ScheduledEmbeddingTrainingSampler, self).next_inputs( + time=time, outputs=outputs, state=state, sample_ids=sample_ids)) + + def maybe_sample(): + """Perform scheduled sampling.""" + where_sampling = math_ops.cast( + array_ops.where(sample_ids > -1), dtypes.int32) + where_not_sampling = math_ops.cast( + array_ops.where(sample_ids <= -1), dtypes.int32) + sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling) + inputs_not_sampling = array_ops.gather_nd(base_next_inputs, + where_not_sampling) + sampled_next_inputs = self.embedding_fn(sample_ids_sampling) + base_shape = array_ops.shape(base_next_inputs) + return (array_ops.scatter_nd( + indices=where_sampling, updates=sampled_next_inputs, shape=base_shape) + + array_ops.scatter_nd( + indices=where_not_sampling, + updates=inputs_not_sampling, + shape=base_shape)) 
+ + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond(all_finished, lambda: base_next_inputs, + maybe_sample) + return (finished, next_inputs, state) + + +class ScheduledOutputTrainingSampler(TrainingSampler): + """A training sampler that adds scheduled sampling directly to outputs. + + Returns False for sample_ids where no sampling took place; True elsewhere. + """ + + def __init__(self, + sampling_probability, + time_major=False, + seed=None, + next_inputs_fn=None): + """Initializer. + + Args: + sampling_probability: A `float32` 0-D or 1-D tensor: the probability of + sampling from the outputs instead of reading directly from the inputs. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + seed: The sampling seed. + next_inputs_fn: (Optional) callable to apply to the RNN outputs to create + the next input when sampling. If `None` (default), the RNN outputs will + be used as the next inputs. + + Raises: + ValueError: if `sampling_probability` is not a scalar or vector. + """ + self.sampling_probability = ops.convert_to_tensor( + sampling_probability, name="sampling_probability") + if self.sampling_probability.get_shape().ndims not in (0, 1): + raise ValueError( + "sampling_probability must be either a scalar or a vector. 
" + "saw shape: %s" % (self.sampling_probability.get_shape())) + + self.seed = seed + self.next_inputs_fn = next_inputs_fn + + super(ScheduledOutputTrainingSampler, self).__init__(time_major=time_major) + + def initialize(self, inputs, sequence_length=None, auxiliary_inputs=None): + if auxiliary_inputs is None: + maybe_concatenated_inputs = inputs + else: + inputs = ops.convert_to_tensor(inputs) + auxiliary_inputs = ops.convert_to_tensor(auxiliary_inputs) + maybe_concatenated_inputs = nest.map_structure( + lambda x, y: array_ops.concat((x, y), -1), inputs, auxiliary_inputs) + if not self.time_major: + auxiliary_inputs = nest.map_structure(_transpose_batch_time, + auxiliary_inputs) + if auxiliary_inputs is not None: + self._auxiliary_input_tas = nest.map_structure(_unstack_ta, + auxiliary_inputs) + else: + self._auxiliary_input_tas = None + + return super(ScheduledOutputTrainingSampler, self).initialize( + maybe_concatenated_inputs, sequence_length=sequence_length) + + def sample(self, time, outputs, state): + del state + return bernoulli_sample( + probs=self.sampling_probability, + sample_shape=self.batch_size, + seed=self.seed) + + def next_inputs(self, time, outputs, state, sample_ids): + (finished, base_next_inputs, state) = ( + super(ScheduledOutputTrainingSampler, self).next_inputs( + time=time, outputs=outputs, state=state, sample_ids=sample_ids)) + sample_ids = math_ops.cast(sample_ids, dtypes.bool) + + def maybe_sample(): + """Perform scheduled sampling.""" + + def maybe_concatenate_auxiliary_inputs(outputs_, indices=None): + """Concatenate outputs with auxiliary inputs, if they exist.""" + if self._auxiliary_input_tas is None: + return outputs_ + + next_time = time + 1 + auxiliary_inputs = nest.map_structure(lambda ta: ta.read(next_time), + self._auxiliary_input_tas) + if indices is not None: + auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices) + return nest.map_structure(lambda x, y: array_ops.concat((x, y), -1), + outputs_, 
auxiliary_inputs) + + if self.next_inputs_fn is None: + return array_ops.where(sample_ids, + maybe_concatenate_auxiliary_inputs(outputs), + base_next_inputs) + + where_sampling = math_ops.cast(array_ops.where(sample_ids), dtypes.int32) + where_not_sampling = math_ops.cast( + array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32) + outputs_sampling = array_ops.gather_nd(outputs, where_sampling) + inputs_not_sampling = array_ops.gather_nd(base_next_inputs, + where_not_sampling) + sampled_next_inputs = maybe_concatenate_auxiliary_inputs( + self.next_inputs_fn(outputs_sampling), where_sampling) + + base_shape = array_ops.shape(base_next_inputs) + return (array_ops.scatter_nd( + indices=where_sampling, updates=sampled_next_inputs, shape=base_shape) + + array_ops.scatter_nd( + indices=where_not_sampling, + updates=inputs_not_sampling, + shape=base_shape)) + + all_finished = math_ops.reduce_all(finished) + no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids)) + next_inputs = control_flow_ops.cond( + math_ops.logical_or(all_finished, no_samples), lambda: base_next_inputs, + maybe_sample) + return (finished, next_inputs, state) + + +class GreedyEmbeddingSampler(Sampler): + """A sampler for use during inference. + + Uses the argmax of the output (treated as logits) and passes the + result through an embedding layer to get the next input. + """ + + def __init__(self, embedding_fn=None): + """Initializer. + + Args: + embedding_fn: A optional callable that takes a vector tensor of `ids` + (argmax ids), or the `params` argument for `embedding_lookup`. The + returned tensor will be passed to the decoder input. Default to use + `embedding_ops.embedding_lookup`. 
+ """ + if embedding_fn is None or callable(embedding_fn): + self.embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be a callable, got %s" % + type(embedding_fn)) + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return tensor_shape.TensorShape([]) + + @property + def sample_ids_dtype(self): + return dtypes.int32 + + def initialize(self, embedding, start_tokens=None, end_token=None): + """Initialize the GreedyEmbeddingSampler. + + Args: + embedding: tensor that contains embedding states matrix. It will be used + to generate generate outputs with start_tokens and end_tokens. The + embedding will be ignored if the embedding_fn has been provided at + __init__(). + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + + Returns: + Tuple of two items: `(finished, self.start_inputs)`. + Raises: + ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a + scalar. 
+ """ + if self.embedding_fn is None: + self.embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self.start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + self.end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self.start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._batch_size = array_ops.size(start_tokens) + if self.end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + self.start_inputs = self.embedding_fn(self.start_tokens) + + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self.start_inputs) + + def sample(self, time, outputs, state): + """sample for GreedyEmbeddingHelper.""" + del time, state # unused by sample_fn + # Outputs are logits, use argmax to get the most probable id + if not isinstance(outputs, ops.Tensor): + raise TypeError( + "Expected outputs to be a single Tensor, got: %s" % type(outputs)) + sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32) + return sample_ids + + def next_inputs(self, time, outputs, state, sample_ids): + """next_inputs_fn for GreedyEmbeddingHelper.""" + del time, outputs # unused by next_inputs_fn + finished = math_ops.equal(sample_ids, self.end_token) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, + # If we're finished, the next_inputs value doesn't matter + lambda: self.start_inputs, + lambda: self.embedding_fn(sample_ids)) + return (finished, next_inputs, state) + + +class SampleEmbeddingSampler(GreedyEmbeddingSampler): + """A sampler for use during inference. + + Uses sampling (from a distribution) instead of argmax and passes the + result through an embedding layer to get the next input. + """ + + def __init__(self, embedding_fn=None, softmax_temperature=None, seed=None): + """Initializer. 
+ + Args: + embedding_fn: (Optional) A callable that takes a vector tensor of `ids` + (argmax ids), or the `params` argument for `embedding_lookup`. The + returned tensor will be passed to the decoder input. + softmax_temperature: (Optional) `float32` scalar, value to divide the + logits by before computing the softmax. Larger values (above 1.0) result + in more random samples, while smaller values push the sampling + distribution towards the argmax. Must be strictly greater than 0. + Defaults to 1.0. + seed: (Optional) The sampling seed. + + Raises: + ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a + scalar. + """ + super(SampleEmbeddingSampler, self).__init__(embedding_fn) + self.softmax_temperature = softmax_temperature + self.seed = seed + + def sample(self, time, outputs, state): + """sample for SampleEmbeddingHelper.""" + del time, state # unused by sample_fn + # Outputs are logits, we sample instead of argmax (greedy). + if not isinstance(outputs, ops.Tensor): + raise TypeError( + "Expected outputs to be a single Tensor, got: %s" % type(outputs)) + if self.softmax_temperature is None: + logits = outputs + else: + logits = outputs / self.softmax_temperature + + return categorical_sample(logits=logits, seed=self.seed) + + +class InferenceSampler(Sampler): + """A helper to use during inference with a custom sampling function.""" + + def __init__(self, + sample_fn, + sample_shape, + sample_dtype, + end_fn, + next_inputs_fn=None): + """Initializer. + + Args: + sample_fn: A callable that takes `outputs` and emits tensor `sample_ids`. + sample_shape: Either a list of integers, or a 1-D Tensor of type `int32`, + the shape of the each sample in the batch returned by `sample_fn`. + sample_dtype: the dtype of the sample returned by `sample_fn`. + end_fn: A callable that takes `sample_ids` and emits a `bool` vector + shaped `[batch_size]` indicating whether each sample is an end token. 
+ next_inputs_fn: (Optional) A callable that takes `sample_ids` and returns + the next batch of inputs. If not provided, `sample_ids` is used as the + next batch of inputs. + """ + self.sample_fn = sample_fn + self.sample_shape = tensor_shape.TensorShape(sample_shape) + self.sample_dtype = sample_dtype + self.end_fn = end_fn + self.next_inputs_fn = next_inputs_fn + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return self.sample_shape + + @property + def sample_ids_dtype(self): + return self.sample_dtype + + def initialize(self, start_inputs): + self.start_inputs = ops.convert_to_tensor(start_inputs, name="start_inputs") + self._batch_size = array_ops.shape(start_inputs)[0] + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self.start_inputs) + + def sample(self, time, outputs, state): + del time, state # unused by sample + return self.sample_fn(outputs) + + def next_inputs(self, time, outputs, state, sample_ids): + del time, outputs # unused by next_inputs + if self.next_inputs_fn is None: + next_inputs = sample_ids + else: + next_inputs = self.next_inputs_fn(sample_ids) + finished = self.end_fn(sample_ids) + return (finished, next_inputs, state) + + +# The following sample functions (_call_sampler, bernoulli_sample, +# categorical_sample) mimic TensorFlow Probability distribution semantics. +def _call_sampler(sample_n_fn, sample_shape, name=None): + """Reshapes vector of samples.""" + with ops.name_scope(name, "call_sampler", values=[sample_shape]): + sample_shape = ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32, name="sample_shape") + # Ensure sample_shape is a vector (vs just a scalar). 
+ pad = math_ops.cast( + math_ops.equal(array_ops.rank(sample_shape), 0), dtypes.int32) + sample_shape = array_ops.reshape( + sample_shape, + array_ops.pad( + array_ops.shape(sample_shape), + paddings=[[pad, 0]], + constant_values=1)) + samples = sample_n_fn(math_ops.reduce_prod(sample_shape)) + batch_event_shape = array_ops.shape(samples)[1:] + final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) + return array_ops.reshape(samples, final_shape) + + +def bernoulli_sample(probs=None, + logits=None, + dtype=dtypes.int32, + sample_shape=(), + seed=None): + """Samples from Bernoulli distribution.""" + if probs is None: + probs = math_ops.sigmoid(logits, name="probs") + else: + probs = ops.convert_to_tensor(probs, name="probs") + batch_shape_tensor = array_ops.shape(probs) + + def _sample_n(n): + """Sample vector of Bernoullis.""" + new_shape = array_ops.concat([[n], batch_shape_tensor], 0) + uniform = random_ops.random_uniform(new_shape, seed=seed, dtype=probs.dtype) + return math_ops.cast(math_ops.less(uniform, probs), dtype) + + return _call_sampler(_sample_n, sample_shape) + + +def categorical_sample(logits, dtype=dtypes.int32, sample_shape=(), seed=None): + """Samples from categorical distribution.""" + logits = ops.convert_to_tensor(logits, name="logits") + event_size = array_ops.shape(logits)[-1] + batch_shape_tensor = array_ops.shape(logits)[:-1] + + def _sample_n(n): + """Sample vector of categoricals.""" + if logits.shape.ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, event_size]) + sample_dtype = dtypes.int64 if logits.dtype.size > 4 else dtypes.int32 + draws = random_ops.multinomial( + logits_2d, n, seed=seed, output_dtype=sample_dtype) + draws = array_ops.reshape( + array_ops.transpose(draws), + array_ops.concat([[n], batch_shape_tensor], 0)) + return math_ops.cast(draws, dtype) + + return _call_sampler(_sample_n, sample_shape) + + +def _unstack_ta(inp): + return tensor_array_ops.TensorArray( + 
dtype=inp.dtype, + size=array_ops.shape(inp)[0], + element_shape=inp.get_shape()[1:]).unstack(inp) diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py index 08983337fccc138d40eb959cecc5bf9e47cf6cac..f3efd292cf5acba4319c8a5545a7f70fae4b5ce1 100644 --- a/tensorflow/contrib/session_bundle/exporter.py +++ b/tensorflow/contrib/session_bundle/exporter.py @@ -304,10 +304,10 @@ class Exporter(object): def parser(path): if os.name == "nt": match = re.match( - "^" + export_dir_base.replace("\\", "/") + "/(\\d{8})$", + r"^" + export_dir_base.replace("\\", "/") + r"/(\d{8})$", path.path.replace("\\", "/")) else: - match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path) + match = re.match(r"^" + export_dir_base + r"/(\d{8})$", path.path) if not match: return None return path._replace(export_version=int(match.group(1))) diff --git a/tensorflow/contrib/session_bundle/gc_test.py b/tensorflow/contrib/session_bundle/gc_test.py index 8faf3ef3d4cd7ee0096265283070e25d06782254..02725bb1cbb4ef9ace29dcc58f6d23fb241d96b2 100644 --- a/tensorflow/contrib/session_bundle/gc_test.py +++ b/tensorflow/contrib/session_bundle/gc_test.py @@ -104,7 +104,7 @@ class GcTest(test_util.TensorFlowTestCase): # create a simple parser that pulls the export_version from the directory. 
def parser(path): - match = re.match("^" + base_dir + "/(\\d+)$", path.path) + match = re.match(r"^" + base_dir + r"/(\d+)$", path.path) if not match: return None return path._replace(export_version=int(match.group(1))) diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 1b2b6acacca838f95cb758ae88f79263993ca69e..c63a3ca19b6a70cf7776c7fce4e0291ee94b775c 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import image_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops @@ -396,8 +396,8 @@ class Image(ItemHandler): image_format = keys_to_tensors[self._format_key] if self._repeated: - return functional_ops.map_fn(lambda x: self._decode(x, image_format), - image_buffer, dtype=self._dtype) + return map_fn.map_fn(lambda x: self._decode(x, image_format), + image_buffer, dtype=self._dtype) else: return self._decode(image_buffer, image_format) diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD index 8bbdf96384683c68648367c6433eeb89c64c22bf..e9595d1b324dbd3d570d2407a6620c5295b15548 100644 --- a/tensorflow/contrib/slim/python/slim/nets/BUILD +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -115,9 +115,9 @@ py_library( py_test( name = "inception_v1_test", - size = "large", + size = "medium", srcs = ["inception_v1_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v1", @@ -135,9 +135,9 @@ py_test( py_test( name = 
"inception_v2_test", - size = "large", + size = "medium", srcs = ["inception_v2_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v2", @@ -155,9 +155,9 @@ py_test( py_test( name = "inception_v3_test", - size = "large", + size = "medium", srcs = ["inception_v3_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v3", @@ -233,8 +233,9 @@ py_library( py_test( name = "resnet_v1_test", - size = "large", + size = "medium", srcs = ["resnet_v1_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":resnet_utils", @@ -268,8 +269,9 @@ py_library( py_test( name = "resnet_v2_test", - size = "large", + size = "medium", srcs = ["resnet_v2_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":resnet_utils", diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD index d7ba754f701d4b433e35ad8396eae7ee6132b97f..ed4eca1a60a6f0ccf629d8aa7906c02092e25ba0 100644 --- a/tensorflow/contrib/sparsemax/BUILD +++ b/tensorflow/contrib/sparsemax/BUILD @@ -49,6 +49,9 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "oss_serial", + ], ) cuda_py_tests( @@ -64,4 +67,7 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index f88b03ec4c2b1f250091594ea12d7d1862029fa2..7dd52df6b68caea6111813837ba1e872acbeccdb 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -4,17 +4,14 @@ exports_files([ "LICENSE", ]) -load( - "//tensorflow:tensorflow.bzl", - "py_test", - "tf_gen_op_wrapper_py", -) +load("//tensorflow:tensorflow.bzl", "py_test") py_test( name = "summary_ops_test", srcs = ["summary_ops_test.py"], srcs_version = "PY2AND3", deps = [ + ":summary", ":summary_test_util", "//tensorflow/python:array_ops", 
"//tensorflow/python:errors", @@ -22,7 +19,6 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:state_ops", - "//tensorflow/python:summary_ops_v2", "//tensorflow/python:training", "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", @@ -35,6 +31,7 @@ py_test( srcs = ["summary_ops_graph_test.py"], srcs_version = "PY2AND3", deps = [ + ":summary", ":summary_test_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -43,7 +40,6 @@ py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:summary_ops_v2", "//tensorflow/python:training", "//tensorflow/python:variables", "@six_archive//:six", diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py index 807741e05f92f6b666c175269742dc1af50c0054..8e13f7f56b23e47f046120b285b1519c6371ddab 100644 --- a/tensorflow/contrib/summary/summary_ops_graph_test.py +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -22,6 +22,7 @@ import time import six +from tensorflow.contrib.summary import summary as summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 @@ -32,7 +33,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 10e4556dacbc17ec02c2bd698389b04d517d7076..27bfdeb3601f4fdb9897feee509b06d5e8f9b873 100644 --- 
a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -25,6 +25,7 @@ import sqlite3 import numpy as np import six +from tensorflow.contrib.summary import summary as summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 @@ -36,7 +37,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 398ac314f4b520610ec100273b37c33bc4b5b43a..583bbf97c57cf263f65bc3b0a56b32cc2dce5482 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -537,8 +537,9 @@ py_library( py_test( name = "random_forest_test", - size = "large", + size = "medium", srcs = ["client/random_forest_test.py"], + shard_count = 6, srcs_version = "PY2AND3", tags = [ "noasan", diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index d8236a0a6fa6d0d0e383e454eb0146bb10b6f49d..0d87cea9fbaa8fe28b55ec996414a568d39efee3 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -50,9 +50,10 @@ def _accuracy(predictions, targets, weights=None): def _r2(probabilities, targets, weights=None): targets = math_ops.to_float(targets) y_mean = math_ops.reduce_mean(targets, 0) - squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0) + squares_total = math_ops.reduce_sum( + math_ops.squared_difference(targets, y_mean), 0) squares_residuals = math_ops.reduce_sum( - 
math_ops.square(targets - probabilities), 0) + math_ops.squared_difference(targets, probabilities), 0) score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) return metrics.mean(score, weights=weights) diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index b9aad36f3d25b9fb7b8b525be54fb7a39394b373..76b1d2b4da269cda71f5b49878f2933d7d9b5776 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -304,7 +304,7 @@ class TraverseTreeV4Op : public OpKernel { auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; const int64 costPerTraverse = 500; - auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource, + auto traverse = [&set_leaf_ids, &data_set, decision_tree_resource, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index fe2c91c1047fe56710b1a86b2fa3206caf6ff3bc..0243f106814511c1b53a5aacb830b845214a00a3 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -307,7 +307,7 @@ class ProcessInputOp : public OpKernel { // from a digits run on local desktop. Heuristics might be necessary // if it really matters that much. 
const int64 costPerUpdate = 1000; - auto update = [this, &target, &leaf_ids_tensor, &num_targets, &data_set, + auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set, fertile_stats_resource, &locks, &set_lock, &ready_to_split, num_data](int64 start, int64 end) { CHECK(start <= end); @@ -317,7 +317,7 @@ class ProcessInputOp : public OpKernel { static_cast(end), &ready_to_split); }; - auto update_collated = [this, &target, &num_targets, fertile_stats_resource, + auto update_collated = [&target, &num_targets, fertile_stats_resource, tree_resource, &leaf_examples, &set_lock, &ready_to_split, &data_set, num_leaves](int64 start, int64 end) { diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h index e04eb60f9b27cfd8b6b4e1502594d4d310ae55cc..774da472f1543f938d1b607ebdef008f7b540211 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h @@ -18,10 +18,10 @@ #include #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index d3edb43733761a906c6e5bf8b65f76e3e1ae56fc..3100a5a0e5da1103b61bd089cd433721686b9e72 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -32,7 +32,7 @@ class DecisionTreeResource : public ResourceBase { // Constructor. 
explicit DecisionTreeResource(const TensorForestParams& params); - string DebugString() override { + string DebugString() const override { return strings::StrCat("DecisionTree[size=", decision_tree_->decision_tree().nodes_size(), "]"); } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h index eea0be27caf0a022ba7acaacd359c75a2df4eedb..44f2b3f473b9eced06bd800b9cf0a5a0825ec3eb 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -40,7 +40,7 @@ class FertileStatsResource : public ResourceBase { model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); } - string DebugString() override { return "FertileStats"; } + string DebugString() const override { return "FertileStats"; } void ExtractFromProto(const FertileStats& stats); diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py index 290c16fe3966791ea78986539750caf938a37322..40bf7081a3f22dfd68fd46f0f61695ee9ca7863b 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking _model_ops = loader.load_op_library( diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py index 9184198cd4c8fd2a7609714d094d5ef2b6868658..80afcfb251f4d6455a9eb8ba5df4a6e43d2feb1c 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py +++ 
b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking _stats_ops = loader.load_op_library( diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 784acce444a8d0c066f1b7ae6c1b5d7d65405549..91b6d2614a8963c21e35c385411dc4c9956e3146 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -11,567 +11,54 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", - "tf_copts", "tf_cuda_library", - "tf_custom_op_library", "tf_custom_op_library_additional_deps", - "tf_gen_op_libs", - "tf_gen_op_wrapper_py", ) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", ) -exports_files(glob([ - "test/testdata/*", -])) - -tf_cuda_cc_test( - name = "tensorrt_test_cc", - size = "small", - srcs = ["tensorrt_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - "//tensorflow/core:gpu_init", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_custom_op_library( - name = "python/ops/_trt_engine_op.so", - srcs = [ - "ops/trt_engine_op.cc", - ], - deps = [ - ":trt_shape_function", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), 
-) - tf_cuda_library( name = "trt_shape_function", srcs = ["shape_fn/trt_shfn.cc"], hdrs = ["shape_fn/trt_shfn.h"], visibility = ["//visibility:public"], deps = [ - ":trt_logging", - ":trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_logging", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]) + tf_custom_op_library_additional_deps(), ) -cc_library( - name = "trt_engine_op_kernel", - srcs = [ - "kernels/trt_engine_op.cc", - ], - hdrs = [ - "kernels/trt_engine_op.h", - ], - copts = tf_copts(), - visibility = ["//visibility:public"], - deps = [ - ":test_utils", - ":trt_allocator", - ":trt_conversion", - ":trt_logging", - ":trt_plugins", - ":trt_resources", - ":utils", - "//tensorflow/core:gpu_headers_lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core/grappler/costs:graph_properties", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd): fix this by merging header file in cc file. 
- alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs -) - -tf_gen_op_libs( - op_lib_names = [ - "trt_engine_op", - ], -) - -tf_cuda_library( - name = "trt_logging", - srcs = ["log/trt_logger.cc"], - hdrs = ["log/trt_logger.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_gen_op_wrapper_py( - name = "trt_engine_op", - deps = [ - ":trt_engine_op_op_lib", - ":trt_logging", - ":trt_shape_function", - ], -) - -tf_custom_op_py_library( - name = "trt_engine_op_loader", - srcs = ["python/ops/trt_engine_op.py"], - dso = [ - ":python/ops/_trt_engine_op.so", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), - kernels = [ - ":trt_engine_op_kernel", - ":trt_engine_op_op_lib", - ":trt_shape_function", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:resources", - ], -) - py_library( name = "init_py", srcs = [ "__init__.py", "python/__init__.py", + "python/trt_convert.py", ], srcs_version = "PY2AND3", deps = [ - ":tf_trt_integration_test_base", - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:errors", - ], -) - -py_library( - name = "trt_ops_py", - srcs_version = "PY2AND3", - deps = [ - ":trt_engine_op", - ":trt_engine_op_loader", - ], -) - -py_library( - name = "trt_convert_py", - srcs = ["python/trt_convert.py"], - srcs_version = "PY2AND3", - deps = [ - ":wrap_conversion", - "//tensorflow/python:graph_util", - "//tensorflow/python:session", - "//tensorflow/python:tf_optimizer", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:tag_constants", - ], -) - -# TODO(aaroey): this wrapper has been causing troubles of double linking, so -# either get rid of it, or split to make it contain minimum dependencies. 
-tf_py_wrap_cc( - name = "wrap_conversion", - srcs = ["trt_conversion.i"], - copts = tf_copts(), - swig_includes = [ - "//tensorflow/python:platform/base.i", - ], - deps = [ - ":test_utils", - ":trt_conversion", - ":trt_engine_op_kernel", - "//third_party/python_runtime:headers", - ], -) - -tf_cuda_library( - name = "trt_resources", - srcs = [ - "resources/trt_int8_calibrator.cc", - "resources/trt_resource_manager.cc", - ], - hdrs = [ - "resources/trt_int8_calibrator.h", - "resources/trt_resource_manager.h", - "resources/trt_resources.h", + "//tensorflow/python/compiler/tensorrt:init_py", ], - deps = [ - ":trt_allocator", - ":trt_logging", - ":utils", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), ) -tf_cuda_library( - name = "trt_allocator", - srcs = ["resources/trt_allocator.cc"], - hdrs = ["resources/trt_allocator.h"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) +# The following rules forward the libraries that were moved in order to not +# break other internal targets. 
-tf_cc_test( - name = "trt_allocator_test", - size = "small", - srcs = ["resources/trt_allocator_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":trt_allocator", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -# Library for the node-level conversion portion of TensorRT operation creation -tf_cuda_library( +alias( name = "trt_conversion", - srcs = [ - "convert/convert_graph.cc", - "convert/convert_nodes.cc", - "convert/trt_optimization_pass.cc", - ], - hdrs = [ - "convert/convert_graph.h", - "convert/convert_nodes.h", - "convert/trt_optimization_pass.h", - ], - deps = [ - ":segment", - ":test_utils", - ":trt_allocator", - ":trt_plugins", - ":trt_logging", - ":trt_resources", - ":utils", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", - "//tensorflow/core:framework_lite", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:devices", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core/grappler/optimizers:meta_optimizer", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), -) - -tf_cuda_cc_test( - name = "convert_graph_test", - size = "medium", - srcs = ["convert/convert_graph_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_conversion", - "@com_google_googletest//:gtest", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core:core_cpu", - 
"//tensorflow/core:core_cpu_base", - "//tensorflow/core:direct_session", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_cuda_cc_test( - name = "convert_nodes_test", - size = "medium", - srcs = ["convert/convert_nodes_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_logging", - ":trt_conversion", - ":trt_plugins", - "@com_google_googletest//:gtest", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - -# Library for the segmenting portion of TensorRT operation creation -cc_library( - name = "segment", - srcs = ["segment/segment.cc"], - hdrs = [ - "segment/segment.h", - "segment/union_find.h", - ], - deps = [ - "//tensorflow/core:graph", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", - "@protobuf_archive//:protobuf_headers", - ], -) - -tf_cc_test( - name = "segment_test", - size = "small", - srcs = ["segment/segment_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":segment", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:lib", - "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - -# Library for the plugin factory -tf_cuda_library( - name = 
"trt_plugins", - srcs = [ - "plugin/trt_plugin.cc", - "plugin/trt_plugin_factory.cc", - "plugin/trt_plugin_utils.cc", - ], - hdrs = [ - "plugin/trt_plugin.h", - "plugin/trt_plugin_factory.h", - "plugin/trt_plugin_utils.h", - ], - deps = [ - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_cuda_cc_test( - name = "trt_plugin_factory_test", - size = "small", - srcs = ["plugin/trt_plugin_factory_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_plugins", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - -py_library( - name = "tf_trt_integration_test_base", - srcs = ["test/tf_trt_integration_test_base.py"], - deps = [ - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], -) - -cuda_py_test( - name = "trt_convert_test", - srcs = ["python/trt_convert_test.py"], - additional_deps = [ - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:graph_util", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:signature_def_utils", - "//tensorflow/python/saved_model:tag_constants", - "//tensorflow/python/saved_model:utils", - "//tensorflow/python/tools:freeze_graph_lib", - "//tensorflow/python/tools:saved_model_utils", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], -) - -cuda_py_tests( - name = "tf_trt_integration_test", - srcs = [ - "test/base_test.py", - "test/batch_matmul_test.py", - "test/biasadd_matmul_test.py", - "test/binary_tensor_weight_broadcast_test.py", - 
"test/concatenation_test.py", - "test/const_broadcast_test.py", - "test/manual_test.py", - "test/memory_alignment_test.py", - "test/multi_connection_neighbor_engine_test.py", - "test/neighboring_engine_test.py", - "test/quantization_test.py", - "test/rank_two_test.py", - "test/reshape_transpose_test.py", - "test/vgg_block_nchw_test.py", - "test/vgg_block_test.py", - ], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], -) - -cuda_py_tests( - name = "tf_trt_integration_test_no_oss", - srcs = [ - "test/unary_test.py", - ], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_oss", # TODO(b/117274186): re-enable in OSS after crash fixed - "no_pip", # TODO(b/117274186): re-enable in OSS after crash fixed - "no_windows", - "nomac", - ], + actual = "//tensorflow/compiler/tf2tensorrt:trt_conversion", ) -cuda_py_test( - name = "quantization_mnist_test", - srcs = ["test/quantization_mnist_test.py"], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python/keras:keras", - "//tensorflow/python/estimator:estimator", - ], - data = [ - "test/testdata/checkpoint", - "test/testdata/model.ckpt-46900.data-00000-of-00001", - "test/testdata/model.ckpt-46900.index", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_pip", - "no_tap", # It is not able to download the mnist data. 
- "no_windows", - "nomac", - ], +alias( + name = "trt_op_kernels", + actual = "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", ) -cc_library( - name = "utils", - srcs = ["convert/utils.cc"], - hdrs = ["convert/utils.h"], - copts = tf_copts(), - deps = [ - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "test_utils", - srcs = ["test/utils.cc"], - hdrs = ["test/utils.h"], - deps = [ - "//tensorflow/core:lib", - "@com_googlesource_code_re2//:re2", - ], +alias( + name = "trt_engine_op_op_lib", + actual = "//tensorflow/compiler/tf2tensorrt:trt_engine_op_op_lib", ) diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md deleted file mode 100644 index caf8b6db0dc0a220d593f9c0afc9464ca51a1e05..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# Using TensorRT in TensorFlow - -This module provides necessary bindings and introduces TRT_engine_op operator -that wraps a subgraph in TensorRT. This is still a work in progress but should -be useable with most common graphs. - -## Compilation - -In order to compile the module, you need to have a local TensorRT installation -(libnvinfer.so and respective include files). During the configuration step, -TensorRT should be enabled and installation path should be set. If installed -through package managers (deb,rpm), configure script should find the necessary -components from the system automatically. If installed from tar packages, user -has to set path to location where the library is installed during configuration. - -```shell -bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package -bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ -``` - -After the installation of tensorflow package, TensorRT transformation will be -available. 
An example use can be found in test/test_tftrt.py script - -## Installing TensorRT 3.0.4 - -In order to make use of TensorRT integration, you will need a local installation -of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt). -Installation instructions for compatibility with TensorFlow are provided on the -[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py index 140ad4828208ae4844a49bf664955b50cd9e51cd..fd551d70b4385b14b84b7b98a6d16b0c03733d38 100644 --- a/tensorflow/contrib/tensorrt/__init__.py +++ b/tensorflow/contrib/tensorrt/__init__.py @@ -18,18 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import errors - -# pylint: disable=unused-import,wildcard-import,g-import-not-at-top -try: - from tensorflow.contrib.tensorrt.python import * -except errors.NotFoundError as e: - no_trt_message = ( - '**** Failed to initialize TensorRT. This is either because the TensorRT' - ' installation path is not in LD_LIBRARY_PATH, or because you do not have' - ' it installed. 
If not installed, please go to' - ' https://developer.nvidia.com/tensorrt to download and install' - ' TensorRT ****') - print(no_trt_message) - raise e -# pylint: enable=unused-import,wildcard-import,g-import-not-at-top +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.tensorrt.python import * +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 69058c5826822c519a69d50860c06b8ab3ec6578..0a2cf105baf5efb62d0c535c1f2d081973ec0ea3 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -45,10 +45,10 @@ tf_custom_op_library( "inc_op_kernel.cu.cc", ], deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", "//tensorflow/core:framework_lite", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]), ) @@ -64,10 +64,10 @@ tf_kernel_library( "inc_op_kernel.cu.cc", ], deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", "//tensorflow/core:stream_executor_headers_lib", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]) + tf_custom_op_library_additional_deps(), ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 8d4c893af56689185da72398919e2241d451594b..7c9075142a02546ddd580e861ac87cb86badd739 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 189e9c939b9ffd4450f7ba95fe1abdbbc049b430..fb048d7b19da0f010ed918b147013b20d37ed0dd 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h deleted file mode 100644 index b545f497f32d5a1a6960b748467ca189b7debf6c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ /dev/null @@ -1,145 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ - -#include -#include - -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/mutex.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { -struct TRTInt8Calibrator; -class TRTCalibrationResource; -class AsyncHelper; -// TODO(Sami): Remove this file? - -// This OP can construct TRTEngine on the fly and if construction of engine -// fails, executes equivalent subgraph as a TensorFlow function. -class TRTEngineOp : public AsyncOpKernel { - public: - explicit TRTEngineOp(OpKernelConstruction* context); - - void ComputeAsync(OpKernelContext* context, - AsyncOpKernel::DoneCallback done) override; - ~TRTEngineOp(); - - private: - // Execute calibration - void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); - - // Construct a function handle for executing native funcdef graph - Status ConstructFunctionHandle(OpKernelContext* ctx); - - // Execute replaced native segment as function Op. - void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); - - // Execute the tensorrt engine. Returns whether we need to retry by running - // the native segment. 
- bool ExecuteTrtEngine(OpKernelContext* ctx, const int num_batch, - nvinfer1::ICudaEngine* trt_engine_ptr, - nvinfer1::IExecutionContext* trt_execution_context_ptr); - - // Allocate necessary resources for calibration - Status AllocateCalibrationResources(OpKernelContext* ctx, - TRTCalibrationResource** cr); - - // TODO(samikama): context should go to a resource manager! - typedef std::pair, - TrtUniquePtrType> - EngineCtxPair; - EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx); - - // Return engine batch closest to input batch. - int GetEngineBatch(OpKernelContext* ctx); - - nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx); - - // map to keep engines and their execution context for given batch size. - std::unordered_map engine_map_; - std::vector input_nodes_; - std::vector output_nodes_; - - // keep device allocator for TRT. - std::unique_ptr allocator_; - - // serialized protobuf segment or trt engine depending on static_engine_ flag. - string serialized_segment_; - - // Name of the function for TF native execution of the segment. - string funcdef_name_; - - // GraphDef representation of the segment. - GraphDef segment_graph_; - - // Lookup table for temporary staging areas of input tensors for calibration. - std::unordered_map> device_buffers_; - - // Temporary staging areas for calibration inputs. - std::vector dev_tensors_; - - // Engine Precision mode. - int precision_mode_; - - // Whether engine is constructed during the conversion or needs to be - // constructed from protobuf segment. - bool static_engine_; - - // Whether to calibrate INT8 engine. - bool calibration_mode_; - - // Whether non-batch ranks of the inputs are assumed to be fixed or not for - // engine construction. 
- bool fixed_input_size_; - - // Batches of the cached engines - std::vector cached_engine_batches_; - - // Maximum number of cached engines - int max_cached_engines_; - - int64 workspace_size_; - mutex engine_mutex_; - FunctionLibraryRuntime::Handle native_func_; - - // The finalized calibrator for inference. - std::unique_ptr calibrator_; - - // If true, create calibration graph for INT8 mode. Otherwise, we are using - // user-provided quantization ranges. - bool use_calibration_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA - -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 7cdfe2b1a612be2eec473d806d0eb44b611ca68a..0cae401023e7d3e3780b9dd2e2a92c9fd0e92db8 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -19,12 +19,6 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.tensorrt.python.ops import trt_engine_op -from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph -from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph -from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value -from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value -from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 203b2697babe32b45523109708cbf062dceee33b..4a959378138dec6f1c1a3f490704d7aebeae9b47 100644 --- 
a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -18,404 +18,41 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six as _six -# pylint: disable=unused-import,line-too-long -from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value -from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert -from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values -from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value -from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version -from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version -from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value -from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled -# pylint: enable=unused-import,line-too-long -from tensorflow.core.framework import graph_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.client import session -from tensorflow.python.framework import errors_impl as _impl -from tensorflow.python.framework import graph_util -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.grappler import tf_optimizer -from tensorflow.python.platform import tf_logging -from tensorflow.python.saved_model import builder -from tensorflow.python.saved_model import loader_impl -from tensorflow.python.saved_model import tag_constants -from tensorflow.python.training import saver - -if _six.PY2: - _to_bytes = lambda s: s - _to_string = lambda s: s -else: - _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape") - _to_string = lambda s: s.decode("utf-8") - - -class TrtPrecisionMode(object): - FP32 = "FP32" - FP16 = "FP16" - 
INT8 = "INT8" - - @staticmethod - def supported_precision_modes(): - return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8] - - -def get_tensorrt_rewriter_config(rewriter_config=None, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=None, - use_calibration=True): - """Returns a RewriterConfig proto for TRT transformation. - - Args: - rewriter_config: a template RewriterConfig proto used to create a - TRT-enabled RewriterConfig. If None, it will use a default one. - max_batch_size: max size for the input batch - max_workspace_size_bytes: the maximum GPU temporary memory which the TRT - engine can use at execution time. This corresponds to the 'workspaceSize' - parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). - precision_mode: one of TrtPrecisionMode.supported_precision_modes(). - minimum_segment_size: the minimum number of nodes required for a subgraph to - be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT - network and engine at run time. - maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. - If the number of cached engines is already at max but none of them can - serve the input, the TRTEngineOp will fall back to run the TF function - based on which the TRTEngineOp is created. - cached_engine_batch_sizes: a list of batch sizes used to create cached - engines, only used when is_dynamic_op is True. The length of the list - should be smaller than maximum_cached_engines, and the dynamic TRT op will - use this list to determine the batch sizes of the cached engines, instead - of making the decision on the fly. This is useful when we know the most - common batch size(s) the application is going to generate. - use_calibration: this argument is ignored if precision_mode is not INT8. 
If - set to True, a calibration graph will be created to calibrate the missing - ranges. The calibration graph must be converted to an inference graph - using calib_graph_to_infer_graph() after running calibration. if set to - False, quantization nodes will be expected for every tensor in the graph - (exlcuding those which will be fused). If a range is missing, an error - will occur. Please note that accuracy may be negatively affected if there - is a mismatch between which tensors TRT quantizes and which tensors were - trained with fake quantization. - - Returns: - A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. - - Raises: - TypeError: if any of the parameters are of unexpected type. - ValueError: if any of the parameters are of unexpected value. - """ - if rewriter_config is not None and not isinstance( - rewriter_config, rewriter_config_pb2.RewriterConfig): - raise TypeError("rewriter_config should be a RewriterConfig proto.") - - rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() - if rewriter_config is None: - # Layout optimizer may add Const nodes followed by Reshape nodes, thus we - # need to run constant folding again. - rewriter_config_with_trt.optimizers.extend( - ["constfold", "layout", "constfold"]) - rewriter_config_with_trt.meta_optimizer_iterations = ( - rewriter_config_pb2.RewriterConfig.ONE) - else: - rewriter_config_with_trt.CopyFrom(rewriter_config) - - if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(): - raise ValueError(("precision mode '{}' is not supported." 
- "It should be one of {}").format( - precision_mode, - TrtPrecisionMode.supported_precision_modes)) - - optimizer = rewriter_config_with_trt.custom_optimizers.add() - optimizer.name = "TensorRTOptimizer" - optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size - optimizer.parameter_map["max_batch_size"].i = max_batch_size - optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op - optimizer.parameter_map[ - "max_workspace_size_bytes"].i = max_workspace_size_bytes - optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode) - optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines - if cached_engine_batch_sizes: - if not isinstance(cached_engine_batch_sizes, list): - raise TypeError("cached_engine_batch_sizes should be a list.") - if len(cached_engine_batch_sizes) > maximum_cached_engines: - raise ValueError("cached_engine_batch_sizes should not contain more than " - "maximum_cached_engines items.") - optimizer.parameter_map["cached_engine_batches"].list.i.extend( - cached_engine_batch_sizes) - optimizer.parameter_map["use_calibration"].b = use_calibration - return rewriter_config_with_trt - - -def create_inference_graph(input_graph_def, - outputs, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=None, - use_calibration=True, - input_saved_model_dir=None, - input_saved_model_tags=None, - output_saved_model_dir=None, - session_config=None): - """Python wrapper for the TRT transformation. - - Args: - input_graph_def: a GraphDef object containing a model to be transformed. If - set to None, the graph will be read from the SavedModel loaded from - input_saved_model_dir. - outputs: list of tensors or node names for the model outputs. Only used when - input_graph_def is not None. - max_batch_size: max size for the input batch. 
- max_workspace_size_bytes: the maximum GPU temporary memory which the TRT - engine can use at execution time. This corresponds to the 'workspaceSize' - parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). - precision_mode: one of TrtPrecisionMode.supported_precision_modes(). - minimum_segment_size: the minimum number of nodes required for a subgraph to - be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT - network and engine at run time. - maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. - If the number of cached engines is already at max but none of them can - serve the input, the TRTEngineOp will fall back to run the TF function - based on which the TRTEngineOp is created. - cached_engine_batch_sizes: a list of batch sizes used to create cached - engines, only used when is_dynamic_op is True. The length of the list - should be smaller than maximum_cached_engines, and the dynamic TRT op will - use this list to determine the batch sizes of the cached engines, instead - of making the decision on the fly. This is useful when we know the most - common batch size(s) the application is going to generate. - use_calibration: this argument is ignored if precision_mode is not INT8. If - set to True, a calibration graph will be created to calibrate the missing - ranges. The calibration graph must be converted to an inference graph - using calib_graph_to_infer_graph() after running calibration. if set to - False, quantization nodes will be expected for every tensor in the graph - (exlcuding those which will be fused). If a range is missing, an error - will occur. Please note that accuracy may be negatively affected if there - is a mismatch between which tensors TRT quantizes and which tensors were - trained with fake quantization. - input_saved_model_dir: the directory to load the SavedModel which contains - the input graph to transforms. Used only when input_graph_def is None. 
- input_saved_model_tags: list of tags to load the SavedModel. - output_saved_model_dir: if not None, construct a SavedModel using the - returned GraphDef and save it to the specified directory. This option only - works when the input graph is loaded from a SavedModel, i.e. when - input_saved_model_dir is specified and input_graph_def is None. - session_config: the ConfigProto used to create a Session. It's also used as - a template to create a TRT-enabled ConfigProto for conversion. If not - specified, a default ConfigProto will be used. - - Returns: - A GraphDef transformed from input_graph_def (or the SavedModel graph def - loaded from input_saved_model_dir, if input_graph_def is not present), where - all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF - function is added for each of the subgraphs. - - If is_dynamic_op is True, each TRTEngineOp will contain a serialized - subgraph GraphDef, which will be converted to a TRT engine at execution time - and the TRT engine will be cached for future usage. A new TRT engine will be - created each time when none of the cached engines match the input shapes. If - it fails to execute the TRT engine or the number of cached engines reaches - maximum_cached_engines, the op will fall back to call the corresponding TF - function. - - If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT - engine created from the corresponding subgraph. No more engines will be - created on the fly, and the op will fall back to call the corresponding TF - function when it fails to execute the engine. - - Raises: - ValueError: if the combination of the parameters is invalid. - RuntimeError: if the TensorRT library version is incompatible. - """ - compiled_version = get_linked_tensorrt_version() - loaded_version = get_loaded_tensorrt_version() - version_mismatch = False - if loaded_version[0] < compiled_version[0]: - tf_logging.error( - "TensorRT version mismatch. 
Tensorflow was compiled against " + - "TensorRT %s but library loaded from environment is TensorRT %s" % - (".".join([str(x) for x in compiled_version]), - ".".join([str(x) for x in loaded_version])) + - ". Please make sure that correct version of TensorRT " + - "is available in the system and added to ldconfig or LD_LIBRARY_PATH") - raise RuntimeError("Incompatible TensorRT library version") - for i in zip(loaded_version, compiled_version): - if i[0] != i[1]: - tf_logging.warn("TensorRT mismatch. Compiled against version " + - "%s, but loaded %s. Things may not work" % - (".".join([str(x) for x in compiled_version]), - ".".join([str(x) for x in loaded_version]))) - version_mismatch = True - break - if not version_mismatch: - tf_logging.info("Running against TensorRT version %s" % ".".join( - [str(x) for x in loaded_version])) - - if session_config is None: - session_config = config_pb2.ConfigProto() - - if input_saved_model_tags is None: - input_saved_model_tags = [tag_constants.SERVING] - saved_model_loader = None - grappler_meta_graph_def = None - - if input_graph_def is None: - # Read from SavedModel and freeze the graph if necessary. - if input_saved_model_dir is None: - raise ValueError("input_graph_def and input_saved_model_dir cannot be " - "both None") - with ops.Graph().as_default(): - with session.Session(config=session_config) as sess: - saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir) - input_meta_graph_def = saved_model_loader.load(sess, - input_saved_model_tags) - output_node_names = set() - - def _gather_names(tensor_info): - """Get the node names from a TensorInfo.""" - return set( - [tensor_info[key].name.split(":")[0] for key in tensor_info]) - - # Get input and outputs from all SignatureDef. 
- for key in input_meta_graph_def.signature_def: - signature_def = input_meta_graph_def.signature_def[key] - output_node_names.update(_gather_names(signature_def.inputs)) - output_node_names.update(_gather_names(signature_def.outputs)) - - # Freeze the variables in the SavedModel graph and copy the frozen - # graph over. - frozen_graph_def = graph_util.convert_variables_to_constants( - sess, sess.graph.as_graph_def(add_shapes=True), - list(output_node_names)) - grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef() - grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def) - - # Copy the collections that are not variables. - for key in input_meta_graph_def.collection_def: - # TODO(laigd): currently we use the collection key to filter out - # collections that depend on variable ops, but this may miss some - # other user-defined collections. A better way would be to use - # CollectionDef::NodeList for the filtering. - if key not in [ - "variables", "local_variables", "model_variables", - "trainable_variables", "train_op", "table_initializer" - ]: - grappler_meta_graph_def.collection_def[key].CopyFrom( - input_meta_graph_def.collection_def[key]) - - # Copy other information. - grappler_meta_graph_def.meta_info_def.CopyFrom( - input_meta_graph_def.meta_info_def) - for key in input_meta_graph_def.signature_def: - grappler_meta_graph_def.signature_def[key].CopyFrom( - input_meta_graph_def.signature_def[key]) - # TODO(laigd): maybe add back AssetFileDef. - else: - if output_saved_model_dir is not None: - raise ValueError("output_saved_model_dir cannot be set when " - "input_graph_def is set") - # Create MetaGraphDef from input graph. 
- graph = ops.Graph() - with graph.as_default(): - importer.import_graph_def(input_graph_def, name="") - grappler_meta_graph_def = saver.export_meta_graph( - graph_def=graph.as_graph_def(add_shapes=True), graph=graph) - if outputs: - output_collection = meta_graph_pb2.CollectionDef() - output_list = output_collection.node_list.value - for i in outputs: - if isinstance(i, ops.Tensor): - output_list.append(_to_bytes(i.name)) - else: - output_list.append(_to_bytes(i)) - # TODO(laigd): use another key as the outputs are really not train_op. - grappler_meta_graph_def.collection_def["train_op"].CopyFrom( - output_collection) - - # Create TRT-enabled ConfigProto. - session_config_with_trt = config_pb2.ConfigProto() - session_config_with_trt.CopyFrom(session_config) - rewriter_config = None - if (session_config_with_trt.HasField("graph_options") and - session_config_with_trt.graph_options.HasField("rewrite_options")): - rewriter_config = session_config_with_trt.graph_options.rewrite_options - rewriter_config_with_trt = get_tensorrt_rewriter_config( - rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode, - minimum_segment_size, is_dynamic_op, maximum_cached_engines, - cached_engine_batch_sizes, use_calibration) - session_config_with_trt.graph_options.rewrite_options.CopyFrom( - rewriter_config_with_trt) - - # Run Grappler. - transformed_graph_def = tf_optimizer.OptimizeGraph( - session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph") - - # Optionally write the transformed graphdef as SavedModel. - if output_saved_model_dir is not None: - saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) - with ops.Graph().as_default(): - importer.import_graph_def(transformed_graph_def, name="") - # We don't use TRT here. 
- with session.Session(config=session_config) as sess: - saved_model_builder.add_meta_graph_and_variables( - sess, - input_saved_model_tags, - signature_def_map=grappler_meta_graph_def.signature_def) - # Ignore other meta graphs from the input SavedModel. - saved_model_builder.save() - - return transformed_graph_def +from tensorflow.python.compiler.tensorrt import trt_convert + + +def create_inference_graph( + input_graph_def, + outputs, + max_batch_size=1, + max_workspace_size_bytes=trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, + precision_mode=trt_convert.TrtPrecisionMode.FP32, + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=None, + use_calibration=True, + input_saved_model_dir=None, + input_saved_model_tags=None, + output_saved_model_dir=None, + session_config=None): + return trt_convert.create_inference_graph( + input_graph_def=input_graph_def, + outputs=outputs, + max_batch_size=max_batch_size, + max_workspace_size_bytes=max_workspace_size_bytes, + precision_mode=precision_mode, + minimum_segment_size=minimum_segment_size, + is_dynamic_op=is_dynamic_op, + maximum_cached_engines=maximum_cached_engines, + cached_engine_batches=cached_engine_batches, + use_calibration=use_calibration, + input_saved_model_dir=input_saved_model_dir, + input_saved_model_tags=input_saved_model_tags, + output_saved_model_dir=output_saved_model_dir, + session_config=session_config) def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): - """Convert an existing calibration graph to inference graph. - - Args: - calibration_graph_def: the calibration GraphDef object with calibration data - is_dynamic_op: whether to create dynamic static engines from calibration - - Returns: - New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. - Raises: - RuntimeError: if the returned status message is malformed. 
- """ - - is_calib_graph = False - for n in calibration_graph_def.node: - if n.op == "TRTEngineOp": - is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s - if not is_calib_graph: - tf_logging.error( - "Not a calib graph. Doesn't seem to contain any calibration nodes.") - return None - graph_str = calibration_graph_def.SerializeToString() - out = calib_convert(graph_str, is_dynamic_op) - status = _to_string(out[0]) - output_graph_def_string = out[1] - del graph_str # Save some memory - if len(status) < 2: - raise _impl.UnknownError(None, None, status) - if status[:2] != "OK": - msg = status.split(";") - if len(msg) == 1: - raise RuntimeError("Status message is malformed {}".format(status)) - # pylint: disable=protected-access - raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), - int(msg[0])) - # pylint: enable=protected-access - output_graph_def = graph_pb2.GraphDef() - output_graph_def.ParseFromString(output_graph_def_string) - del output_graph_def_string # Save some memory - return output_graph_def + return trt_convert.calib_graph_to_infer_graph( + calibration_graph_def=calibration_graph_def, is_dynamic_op=is_dynamic_op) diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc deleted file mode 100644 index 9c3698e5d1cc5d6d8d31a8fcaf03d103f1e1915d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace tensorrt { - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::instance() { - static std::shared_ptr instance_(new TRTResourceManager); - return instance_; -} - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) { - // mutex is held for lookup only. Most instantiations where mutex will be held - // longer will be during op creation and should be ok. - tensorflow::mutex_lock lock(map_mutex_); - auto s = managers_.find(op_name); - if (s == managers_.end()) { - auto it = managers_.emplace( - op_name, std::make_shared(op_name)); - VLOG(1) << "Returning a new manager " << op_name; - return it.first->second; - } - VLOG(1) << "Returning old manager " << op_name; - return s->second; -} - -} // namespace tensorrt -} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h deleted file mode 100644 index 19f39e6d3db1571573fb290dd2c30fd43ea604ef..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ -#include - -#include -#include -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { -namespace tensorrt { - -class TRTResourceManager { - TRTResourceManager() = default; - - public: - static std::shared_ptr instance(); - // returns a manager for given op, if it doesn't exists it creates one - std::shared_ptr getManager(const string& op_name); - - private: - std::unordered_map> - managers_; - tensorflow::mutex map_mutex_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h deleted file mode 100644 index aac9e5c7bd725fc10bcaa04536ebc7be071b4d4c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/core/framework/resource_mgr.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT - -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { - -class TRTCalibrationResource : public tensorflow::ResourceBase { - public: - ~TRTCalibrationResource() { - LOG(INFO) << "Destroying Calibration Resource " << std::endl - << DebugString(); - builder_.reset(); - engine_.reset(); - // We need to manually destroy the builder and engine before the allocator - // is destroyed. 
- allocator_.reset(); - } - - string DebugString() override { - std::stringstream oss; - using std::dec; - using std::endl; - using std::hex; - oss << " Calibrator = " << hex << calibrator_.get() << dec << endl - << " Builder = " << hex << builder_.get() << dec << endl - << " Engine = " << hex << engine_.get() << dec << endl - << " Logger = " << hex << &logger_ << dec << endl - << " Allocator = " << hex << allocator_.get() << dec << endl - << " Thread = " << hex << thr_.get() << dec << endl; - return oss.str(); - } - - std::unique_ptr calibrator_; - TrtUniquePtrType builder_; - TrtUniquePtrType engine_; - std::unique_ptr allocator_; - tensorflow::tensorrt::Logger logger_; - // TODO(sami): Use threadpool threads! - std::unique_ptr thr_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif -#endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index f30dba59ad55317d7ad7730e4dc66c9aba4e6a6b..5c60d6b589ed6a16276226726d989e949bcbf9d7 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -14,14 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorrt/include/NvInfer.h" diff --git a/tensorflow/contrib/tensorrt/test/manual_test.py b/tensorflow/contrib/tensorrt/test/manual_test.py deleted file mode 100644 index 1187c759b4b5483cbf5afe136401abe86d6ef989..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/manual_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Basic tests for TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ast -import os - -from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test -from tensorflow.core.framework import graph_pb2 -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test - - -class ManualTest(trt_test.TfTrtIntegrationTestBase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super(ManualTest, self).__init__(methodName) - self._params_map = None - - def _GetEnv(self): - """Get an environment variable specifying the manual test parameters. - - The value of the environment variable is the string representation of a dict - which should contain the following keys: - - 'graph_path': the file path to the serialized frozen graphdef - - 'input_names': TfTrtIntegrationTestParams.input_names - - 'input_dims': TfTrtIntegrationTestParams.input_dims - - 'expected_output_dims': TfTrtIntegrationTestParams.expected_output_dims - - 'output_name': the name of op to fetch - - 'expected_engines_to_run': ExpectedEnginesToRun() will return this - - 'expected_engines_to_build': ExpectedEnginesToBuild() will return this - - 'max_batch_size': ConversionParams.max_batch_size - - Returns: - The value of the environment variable. 
- """ - return os.getenv('TRT_MANUAL_TEST_PARAMS', '') - - def _GetParamsMap(self): - """Parse the environment variable as a dict and return it.""" - if self._params_map is None: - self._params_map = ast.literal_eval(self._GetEnv()) - return self._params_map - - def GetParams(self): - """Testing conversion of manually provided frozen graph.""" - params_map = self._GetParamsMap() - gdef = graph_pb2.GraphDef() - with gfile.Open(params_map['graph_path'], 'rb') as f: - gdef.ParseFromString(f.read()) - return trt_test.TfTrtIntegrationTestParams( - gdef=gdef, - input_names=params_map['input_names'], - input_dims=params_map['input_dims'], - output_names=params_map['output_names'], - expected_output_dims=params_map['expected_output_dims']) - - def GetConversionParams(self, run_params): - """Return a ConversionParams for test.""" - conversion_params = super(ManualTest, self).GetConversionParams(run_params) - params_map = self._GetParamsMap() - if 'max_batch_size' in params_map: - conversion_params = conversion_params._replace( - max_batch_size=params_map['max_batch_size']) - return conversion_params - - def ExpectedEnginesToBuild(self, run_params): - """Return the expected engines to build.""" - return self._GetParamsMap()['expected_engines_to_build'] - - def ExpectedEnginesToRun(self, run_params): - """Return the expected engines to run.""" - params_map = self._GetParamsMap() - if 'expected_engines_to_run' in params_map: - return params_map['expected_engines_to_run'] - return self.ExpectedEnginesToBuild(run_params) - - def ExpectedAbsoluteTolerance(self, run_params): - """The absolute tolerance to compare floating point results.""" - params_map = self._GetParamsMap() - if 'atol' in params_map: - return params_map['atol'] - return 1.e-3 - - def ExpectedRelativeTolerance(self, run_params): - """The relative tolerance to compare floating point results.""" - params_map = self._GetParamsMap() - if 'rtol' in params_map: - return params_map['rtol'] - return 1.e-3 - - def 
ShouldRunTest(self, run_params): - """Whether to run the test.""" - return len(self._GetEnv()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py deleted file mode 100644 index d26f26008635733c6c364a98b72b88c1e552f5fe..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Script to test TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import numpy as np -import six as _six - -# normally we should do import tensorflow as tf and then -# tf.placeholder, tf.constant, tf.nn.conv2d etc but -# it looks like internal builds don't like it so -# importing every module individually - -from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.core.protobuf import rewriter_config_pb2 as rwpb2 -from tensorflow.python.client import session as csess -from tensorflow.python.framework import constant_op as cop -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import math_ops as mops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops - - -def py2bytes(inp): - return inp - - -def py3bytes(inp): - return inp.encode("utf-8", errors="surrogateescape") - - -def py2string(inp): - return inp - - -def py3string(inp): - return inp.decode("utf-8") - - -if _six.PY2: - to_bytes = py2bytes - to_string = py2string -else: - to_bytes = py3bytes - to_string = py3string - - -def get_multi_engine_graph_def(mode="FP32"): - """Create a simple graph and return its graph_def.""" - dtype = dtypes.float32 - if mode.upper() == "FP16": - dtype = dtypes.float16 - else: - pass - - g = ops.Graph() - with g.as_default(): - x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype) - with g.name_scope("Global_scope"): - with g.name_scope("first_scope"): - e = cop.constant( - np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype) - conv = nn.conv2d( - input=x, - 
filter=e, - data_format="NCHW", - strides=[1, 1, 1, 1], - padding="VALID", - name="conv") - b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype) - t = conv * b - - b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype) - q = conv / b - edge = mops.sin(q) - edge1 = mops.cos(conv) - with g.name_scope("test_scope"): - de = edge + edge1 - t -= edge1 - q *= edge - t += q - t -= de - k = aops.squeeze(t, name="output") - print(k.dtype) - return g.as_graph_def() - - -def get_simple_graph_def(): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = aops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - e = cop.constant( - [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], - name="weights", - dtype=dtypes.float32) - conv = nn.conv2d( - input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") - b = cop.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32) - t = nn.bias_add(conv, b, name="biasAdd") - relu = nn.relu(t, "relu") - idty = aops.identity(relu, "ID") - v = nn_ops.max_pool( - idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - aops.squeeze(v, name="output") - return g.as_graph_def() - - -def execute_graph(gdef, dumm_inp): - """Run given graphdef once.""" - print("executing") - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - sessconfig = cpb2.ConfigProto(gpu_options=gpu_options) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session(config=sessconfig, graph=g) as sess: - val = sess.run(out, {inp: dumm_inp}) - return val - - -# Use real data that is representative of the inference dataset -# for calibration. 
For this test script it is random data. -def execute_calibration(gdef, dumm_inp): - """Run given calibration graph multiple times.""" - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session( - config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: - # run over real calibration data here, we are mimicking a calibration set of - # 30 different batches. Use as much calibration data as you want - for _ in range(30): - val = sess.run(out, {inp: dumm_inp}) - return val - - -def user(multi_engine, - run_graph=execute_graph, - run_calibration=execute_calibration): - """Example function that converts a graph to TFTRT graph.""" - if multi_engine: - inp_dims = (2, 3, 7, 5) - orig_graph = get_multi_engine_graph_def() - else: - inp_dims = (100, 24, 24, 2) - orig_graph = get_simple_graph_def() # use a frozen graph for inference - dummy_input = np.random.random_sample(inp_dims) - # Get optimized graph - trt_graph = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=[]) - o1 = run_graph(orig_graph, dummy_input) - o2 = run_graph(trt_graph, dummy_input) - o3 = run_graph(trt_graph, dummy_input) - assert np.array_equal(o1, o2) - assert np.array_equal(o3, o2) # sanity check - fp16_graph = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - 
precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=[]) - int8_calib_gdef = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=[]) - o4 = run_graph(fp16_graph, dummy_input) - _ = run_calibration(int8_calib_gdef, dummy_input) - int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) - o5 = run_graph(int8_graph, dummy_input) - print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4)) - print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5)) - print("Pass") - - -def auto(multi_engine): - """Run the conversion as an optimization pass.""" - if multi_engine: - inp_dims = (2, 3, 7, 5) - orig_graph = get_multi_engine_graph_def() - else: - inp_dims = (100, 24, 24, 2) - orig_graph = get_simple_graph_def() # use a frozen graph for inference - dummy_input = np.random.random_sample(inp_dims) - opt_config = rwpb2.RewriterConfig() - opt_config.meta_optimizer_iterations = opt_config.ONE - opt_config.optimizers.extend(["constfold", "layout"]) - custom_op = opt_config.custom_optimizers.add() - custom_op.name = "TensorRTOptimizer" - custom_op.parameter_map["minimum_segment_size"].i = 3 - custom_op.parameter_map["precision_mode"].s = to_bytes("FP32") - custom_op.parameter_map["max_batch_size"].i = inp_dims[0] - custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 - print(custom_op) - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - graph_options = 
cpb2.GraphOptions(rewrite_options=opt_config) - sessconfig = cpb2.ConfigProto( - gpu_options=gpu_options, graph_options=graph_options) - print(sessconfig) - g = ops.Graph() - ops.reset_default_graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=orig_graph, return_elements=["input", "output"], name="") - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session(config=sessconfig, graph=g) as sess: - val = sess.run(out, {inp: dummy_input}) - print(val.shape) - - -if "__main__" in __name__: - P = argparse.ArgumentParser( - prog="tftrt_test", - description="Example utilization of TensorFlow-TensorRT integration") - P.add_argument( - "--automatic", - "-a", - action="store_true", - help="Do TRT conversion automatically", - default=False) - P.add_argument( - "--multi-engine", - "-m", - action="store_true", - help="Use a graph that will result in 2 engines", - default=False) - flags, unparsed = P.parse_known_args() - if flags.automatic: - auto(flags.multi_engine) - else: - user(flags.multi_engine) diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 57797214d1684550aa7ad2664b71d22b504f70ed..e10be88ece8ebba9635af955b3c3410f29e5503c 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -105,6 +105,7 @@ py_binary( data = ["data/multivariate_periods.csv"], srcs_version = "PY2AND3", tags = ["no_pip"], + visibility = ["//visibility:public"], deps = select({ ":empty_condition": [], "//conditions:default": [], @@ -113,6 +114,7 @@ py_binary( "//tensorflow:tensorflow_py", "//tensorflow/contrib/timeseries/python/timeseries:estimators", "//tensorflow/contrib/timeseries/python/timeseries:model", + "//tensorflow/contrib/timeseries/python/timeseries:state_management", ], ) diff --git a/tensorflow/contrib/timeseries/examples/predict_test.py b/tensorflow/contrib/timeseries/examples/predict_test.py index 
678fd71cd8b94ee0be46e10a9a673de55bd44215..b353f85cb5df0cf961d1900b241e4fa1a84a24b4 100644 --- a/tensorflow/contrib/timeseries/examples/predict_test.py +++ b/tensorflow/contrib/timeseries/examples/predict_test.py @@ -43,10 +43,6 @@ class PeriodTrendExampleTest(test.TestCase): self.assertAllEqual([700], mean.shape) self.assertAllEqual([700], upper_limit.shape) self.assertAllEqual([700], lower_limit.shape) - # Check that variance hasn't blown up too much. This is a relatively good - # indication that training was successful. - self.assertLess(upper_limit[-1] - lower_limit[-1], - 1.5 * (upper_limit[0] - lower_limit[0])) def test_ar(self): (times, observed, all_times, mean, @@ -55,7 +51,6 @@ class PeriodTrendExampleTest(test.TestCase): self.assertAllEqual(all_times.shape, mean.shape) self.assertAllEqual(all_times.shape, upper_limit.shape) self.assertAllEqual(all_times.shape, lower_limit.shape) - self.assertLess((upper_limit - lower_limit).mean(), 4.) if __name__ == "__main__": diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 4b90b596b28efec83aa349782c4874d79b6817c7..4ba814b9e3d3621f9ab924961e2740885fa93b33 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -155,13 +155,16 @@ py_library( py_test( name = "head_test", - size = "large", + size = "medium", srcs = [ "head_test.py", ], - shard_count = 4, + shard_count = 10, srcs_version = "PY2AND3", - tags = ["no_pip_gpu"], # b/63391119 + tags = [ + "no_pip_gpu", # b/63391119 + "notap", # b/124520733 + ], deps = [ ":estimators", ":feature_keys", @@ -281,6 +284,7 @@ py_library( "input_pipeline.py", ], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":feature_keys", ":model_utils", @@ -361,9 +365,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_keys", + ":math_utils", ":model", ":model_utils", - 
"//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index bcadf4094e1e79fff1685515f2bde0b88f717cac..3626701d24163ef52564b42d8a630bd9c5a788eb 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -18,9 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import model_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures @@ -462,11 +461,12 @@ class ARModel(model.TimeSeriesModel): if self.loss == ARModel.NORMAL_LIKELIHOOD_LOSS: covariance = prediction_ops["covariance"] sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=sigma) - loss_op = -math_ops.reduce_sum(normal.log_prob(prediction)) + loss_op = -math_ops.reduce_sum( + math_utils.normal_log_prob(targets, sigma, prediction)) else: assert self.loss == ARModel.SQUARED_LOSS, self.loss - loss_op = math_ops.reduce_sum(math_ops.square(prediction - targets)) + loss_op = math_ops.reduce_sum( + math_ops.squared_difference(prediction, targets)) loss_op /= math_ops.cast( math_ops.reduce_prod(array_ops.shape(targets)), loss_op.dtype) return loss_op @@ -965,16 +965,11 @@ class AnomalyMixtureARModel(ARModel): anomaly_variance = prediction_ops["anomaly_params"] anomaly_sigma = math_ops.sqrt( 
gen_math_ops.maximum(anomaly_variance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=anomaly_sigma) - log_prob = normal.log_prob(prediction) + log_prob = math_utils.normal_log_prob(targets, anomaly_sigma, prediction) else: assert self._anomaly_distribution == AnomalyMixtureARModel.CAUCHY_ANOMALY anomaly_scale = prediction_ops["anomaly_params"] - cauchy = distributions.StudentT( - df=array_ops.ones([], dtype=anomaly_scale.dtype), - loc=targets, - scale=anomaly_scale) - log_prob = cauchy.log_prob(prediction) + log_prob = math_utils.cauchy_log_prob(targets, anomaly_scale, prediction) return log_prob def loss_op(self, targets, prediction_ops): @@ -983,8 +978,7 @@ class AnomalyMixtureARModel(ARModel): covariance = prediction_ops["covariance"] # Normal data log probability. sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal1 = distributions.Normal(loc=targets, scale=sigma) - log_prob1 = normal1.log_prob(prediction) + log_prob1 = math_utils.normal_log_prob(targets, sigma, prediction) log_prob1 += math_ops.log(1 - self._anomaly_prior_probability) # Anomaly log probability. 
log_prob2 = self._anomaly_log_prob(targets, prediction_ops) diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index aab330643862c1ccf073d2a0e34e1c475b1ec15f..b7375e5055e29efea3f23c3b9b9f3af59f45495b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -21,6 +21,8 @@ from __future__ import print_function import collections import math +import numpy as np + from tensorflow.contrib import lookup from tensorflow.contrib.layers.python.layers import layers @@ -43,6 +45,32 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest +def normal_log_prob(loc, scale, x): + """Computes the Normal log pdf of `x` given mean `loc` and stddev `scale`.""" + z = (x - loc) / scale + return -0.5 * (math_ops.square(z) + + np.log(2. * np.pi) + 2. * math_ops.log(scale)) + + +def cauchy_log_prob(loc, scale, x): + """Computes the Cauchy log pdf.""" + z = (x - loc) / scale + return (-np.log(np.pi) - math_ops.log(scale) - + math_ops.log1p(math_ops.square(z))) + + +def mvn_tril_log_prob(loc, scale_tril, x): + """Computes the MVN log pdf under tril scale. Doesn't handle batches.""" + x0 = x - loc + z = linalg_ops.matrix_triangular_solve( + scale_tril, x0[..., array_ops.newaxis])[..., 0] + log_det_cov = 2. * math_ops.reduce_sum(math_ops.log( + array_ops.matrix_diag_part(scale_tril)), axis=-1) + d = math_ops.cast(array_ops.shape(scale_tril)[-1], log_det_cov.dtype) + return -0.5 * (math_ops.reduce_sum(math_ops.square(z), axis=-1) + + d * np.log(2. * np.pi) + log_det_cov) + + def clip_covariance( covariance_matrix, maximum_variance_ratio, minimum_variance): """Enforce constraints on a covariance matrix to improve numerical stability. 
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 125750e7639ad40c481472a93353e6fb7055be96..cf5e749042afd83f927a3d22edfd3a9538ab2ffd 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -78,7 +78,6 @@ py_library( srcs = ["kalman_filter.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -235,7 +234,6 @@ py_library( srcs = ["filtering_postprocessor.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py index e9e2ac0aaf4c4d6c41f5007662f261af3de9bbd1..3fa2fbd9f77cb887c30fde264815728ca345f45a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py @@ -22,8 +22,6 @@ import abc import six -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -91,10 +89,10 @@ def cauchy_alternative_to_gaussian(current_times, current_values, outputs): """ del current_times # unused cauchy_scale = math_utils.entropy_matched_cauchy_scale(outputs["covariance"]) - individual_log_pdfs = distributions.StudentT( - df=array_ops.ones([], 
dtype=current_values.dtype), + individual_log_pdfs = math_utils.cauchy_log_prob( loc=outputs["mean"], - scale=cauchy_scale).log_prob(current_values) + scale=cauchy_scale, + x=current_values) return math_ops.reduce_sum(individual_log_pdfs, axis=1) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py index a614386121e000961bf8b32625a28e1251654320..c0ec797bc5b7c41ca996c807840ce38311201f87 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -137,9 +135,10 @@ class KalmanFilter(object): with ops.control_dependencies([non_negative_assert]): observation_covariance_cholesky = linalg_ops.cholesky( symmetrized_observation_covariance) - log_prediction_prob = distributions.MultivariateNormalTriL( - predicted_observation, observation_covariance_cholesky).log_prob( - observation) + log_prediction_prob = math_utils.mvn_tril_log_prob( + loc=predicted_observation, + scale_tril=observation_covariance_cholesky, + x=observation) (posterior_state, posterior_state_var) = self.posterior_from_prior_state( prior_state=estimated_state, diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 4bf3a0463d9046eea2f60e9154fca1357e728215..7c1661d20f15f94a929a46dafc79d59ca73e53cb 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -1,15 +1,15 @@ # Description: Operations defined for Cloud TPUs -licenses(["notice"]) # Apache 2.0 - load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", 
"tf_gen_op_libs", "tf_gen_op_wrapper_py", + "tf_py_test", ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_py_test") + +licenses(["notice"]) # Apache 2.0 package( default_visibility = [ @@ -23,17 +23,12 @@ package( ], ) -cc_library( - name = "all_ops", +py_library( + name = "tpu_py", + srcs = ["python/ops/tpu_ops.py"], + srcs_version = "PY2AND3", deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", + "//tensorflow/python/tpu:tpu_py", ], ) @@ -42,25 +37,14 @@ py_library( srcs = ["python/tpu/async_checkpoint.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:summary_ops_v2", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/tpu:async_checkpoint", ], ) py_library( name = "tpu_estimator", srcs = [ + "python/tpu/_tpu_estimator_embedding.py", "python/tpu/error_handling.py", "python/tpu/tpu_config.py", "python/tpu/tpu_context.py", @@ -70,86 +54,24 @@ py_library( srcs_version = "PY2AND3", deps = [ ":async_checkpoint", + ":feature_column", + ":functional", + ":tpu_embedding", ":tpu_lib", "//tensorflow/contrib/training:training_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - 
"//tensorflow/python:platform", - "//tensorflow/python:session", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:summary_ops_v2", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:util", - "@six_archive//:six", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "cross_replica_ops", - "heartbeat_ops", - "host_compute_ops", - "infeed_ops", - "outfeed_ops", - "replication_ops", - "tpu_configuration_ops", - "tpu_embedding_ops", - ], - deps = [ - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", + "//tensorflow/python/tpu:tpu_estimator", ], ) -tf_custom_op_library( - name = "python/ops/_tpu_ops.so", - srcs = [ - "ops/cross_replica_ops.cc", - "ops/heartbeat_ops.cc", - "ops/host_compute_ops.cc", - "ops/infeed_ops.cc", - "ops/outfeed_ops.cc", - "ops/replication_ops.cc", - "ops/tpu_configuration_ops.cc", - "ops/tpu_embedding_ops.cc", - ], - deps = [ - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib_proto_parsing", - ], -) - -tf_gen_op_wrapper_py( - name = "tpu_ops", - hidden = [ - "SendTPUEmbeddingGradients", - "EnqueueTPUEmbeddingIntegerBatch", - "EnqueueTPUEmbeddingSparseBatch", - "EnqueueTPUEmbeddingSparseTensorBatch", +py_library( + name = "functional", + srcs = ["python/tpu/functional.py"], + srcs_version = "PY2AND3", + visibility = [ + "//visibility:public", ], deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", 
- ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", + "//tensorflow/python/tpu:functional", ], ) @@ -158,30 +80,7 @@ py_library( srcs = ["python/profiler/__init__.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_pb2_grpc", - "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_proto_py", - "//tensorflow/contrib/tpu/profiler:trace_events_proto_py", - "//tensorflow/python:util", - ], -) - -tf_custom_op_py_library( - name = "tpu_py", - srcs = glob(["python/ops/*.py"]), - dso = [":python/ops/_tpu_ops.so"], - kernels = [ - ":all_ops", - ], - srcs_version = "PY2AND3", - deps = [ - ":profiler", - ":tpu_ops", - "//tensorflow/contrib/compiler:xla", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow/python/tpu/profiler", ], ) @@ -193,10 +92,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":feature_column", ":keras_support", # split out to avoid cycle with tpu_strategy ":tpu_embedding", ":tpu_estimator", ":tpu_lib", + "//tensorflow/python/tpu", ], ) @@ -212,7 +113,6 @@ py_library( "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", "//third_party/cloud_tpu/models/keras_colab:__subpackages__", - "//third_party/cloud_tpu/models/mnist_keras:__subpackages__", "//third_party/cloud_tpu/models/resnet50_keras:__subpackages__", ], deps = [ @@ -220,8 +120,8 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/distribute", "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/core:protos_all_py", + "//tensorflow/core/protobuf/tpu:compilation_result_proto_py", "//tensorflow/python:array_ops", 
"//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -261,29 +161,12 @@ py_library( srcs_version = "PY2AND3", deps = [ ":datasets", + ":functional", ":profiler", ":tpu_py", - "//tensorflow/compiler/xla/experimental/xla_sharding", - "//tensorflow/compiler/xla/python_api:xla_shape", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/compiler:xla", - "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", - "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py", - "//tensorflow/contrib/tpu/proto:topology_proto_py", - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", - "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:control_flow_util", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/losses", + "//tensorflow/python/tpu:tpu_lib", ], ) @@ -294,121 +177,28 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", - ], -) - -tf_py_test( - name = "datasets_test", - srcs = ["python/tpu/datasets_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - ":datasets", + "//tensorflow/python/tpu:datasets", ], - flaky = 1, # TODO(b/117363808): fails 1/1000 OSS runs - grpc_enabled = True, ) -tf_py_test( - name = "tpu_test", - size = "small", - srcs = 
["python/tpu/tpu_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:layers", - ], - tags = ["no_windows"], # TODO: needs investigation on Windows -) - -tf_py_test( - name = "tpu_sharding_test", - size = "small", - srcs = ["python/tpu/tpu_sharding_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - ], -) - -tf_py_test( - name = "bfloat16_test", - size = "small", - srcs = ["python/tpu/bfloat16_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - ], -) - -tf_py_test( - name = "tpu_infeed_test", - size = "small", - srcs = ["python/tpu/tpu_infeed_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - ], -) - -tf_py_test( - name = "tpu_config_test", - size = "small", - srcs = ["python/tpu/tpu_config_test.py"], - additional_deps = [ - ":tpu_estimator", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - ], -) - -tf_py_test( - name = "tpu_estimator_signals_test", - size = "small", - srcs = ["python/tpu/tpu_estimator_signals_test.py"], - additional_deps = [ - ":tpu_estimator", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", +py_library( + name = "tpu_embedding", + srcs = [ + "python/tpu/tpu_embedding.py", + "python/tpu/tpu_embedding_gradient.py", ], -) - -tf_py_test( - name = "topology_test", - size = "medium", - srcs = ["python/tpu/topology_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:framework_test_lib", + srcs_version = "PY2AND3", + deps = [ + ":tpu_lib", + "//tensorflow/python/tpu:tpu_embedding", ], ) py_library( - name = "tpu_embedding", - srcs = ["python/tpu/tpu_embedding.py"], - srcs_version = "PY2AND3", + name = "feature_column", + srcs = 
["python/tpu/feature_column.py"], deps = [ - "//tensorflow/contrib/tpu:tpu_ops", - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "@six_archive//:six", + ":tpu_lib", + "//tensorflow/python/tpu:feature_column", ], ) diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 541fbf33a302a4d850422885fdbbc438bd6b9b7b..e2ce77e118182bb07193cbac82e176d3b2057e17 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -2,35 +2,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_cc") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos") - -tf_proto_library( - name = "tpu_profiler_proto", - srcs = ["tpu_profiler.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = [":op_profile_proto"] + tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -cc_library( - name = "dump_tpu_profile", - srcs = ["dump_tpu_profile.cc"], - hdrs = ["dump_tpu_profile.h"], - visibility = ["//visibility:public"], - deps = [ - ":op_profile_proto_cc", - ":tpu_profiler_proto_cc", - ":trace_events_proto_cc", - ":trace_events_to_json", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], -) cc_library( name = "version", @@ -43,71 +14,13 @@ tf_cc_binary( srcs = [ "capture_tpu_profile.cc", ], + tags = ["no_windows"], visibility = 
["//visibility:public"], deps = [ - ":dump_tpu_profile", - ":tpu_profiler_analysis_proto_cc", - ":tpu_profiler_proto_cc", ":version", - "//tensorflow:grpc++", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/profiler/rpc/client:capture_profile", ], ) - -tf_proto_library( - name = "trace_events_proto", - srcs = ["trace_events.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - -cc_library( - name = "trace_events_to_json", - srcs = ["trace_events_to_json.cc"], - hdrs = ["trace_events_to_json.h"], - deps = [ - ":trace_events_proto_cc", - "//tensorflow/core:lib", - "@jsoncpp_git//:jsoncpp", - ], -) - -tf_cc_test( - name = "trace_events_to_json_test", - srcs = ["trace_events_to_json_test.cc"], - deps = [ - ":trace_events_to_json", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@jsoncpp_git//:jsoncpp", - ], -) - -tf_proto_library( - name = "op_profile_proto", - srcs = ["op_profile.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "tpu_profiler_analysis_proto", - srcs = ["tpu_profiler_analysis.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = [":tpu_profiler_proto"] + tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -py_library( - name = "tpu_profiler_analysis_pb2_grpc", - srcs = ["tpu_profiler_analysis_pb2_grpc.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [":tpu_profiler_analysis_proto_py"], -) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 1c5ea2d997a58ca57ddc212ffd56aad525e961da..f11d1a9f37eeb19b95a876bd68575022e6b91521 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ 
b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,235 +18,11 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. -#include "grpcpp/grpcpp.h" - -#include -#include -#include - -#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h" -#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" -#include "tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.grpc.pb.h" #include "tensorflow/contrib/tpu/profiler/version.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/util/command_line_flags.h" -namespace tensorflow { -namespace tpu { -namespace { - -using ::tensorflow::TPUProfileAnalysis; -using ::tensorflow::TPUProfiler; - -constexpr uint64 kMaxEvents = 1000000; - -string GetCurrentTimeStampAsString() { - char s[128]; - std::time_t t = std::time(nullptr); - CHECK_NE(std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t)), 0); - return s; -} - -Status ValidateHostPortPair(const string& host_port) { - uint32 port; - std::vector parts = str_util::Split(host_port, ':'); - // Must be host:port, port must be a number, host must not contain a '/', - // host also must not be empty. 
- if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || - parts[0].find("/") != string::npos || parts[0].empty()) { - return errors::InvalidArgument("Could not interpret \"", host_port, - "\" as a host-port pair."); - } - return Status::OK(); -} - -ProfileRequest PopulateProfileRequest(int duration_ms, - const string& repository_root, - const string& session_id, - const ProfileOptions& opts) { - ProfileRequest request; - request.set_duration_ms(duration_ms); - request.set_max_events(kMaxEvents); - if (tensorflow::str_util::StartsWith(repository_root, "gs://")) { - // For backward compatibilities, only generate tracetable etc when the - // user provide a GCS path for model directory. - request.set_repository_root(repository_root); - request.set_session_id(session_id); - } - request.add_tools("op_profile"); - request.add_tools("input_pipeline"); - request.add_tools("memory_viewer"); - request.add_tools("overview_page"); - *request.mutable_opts() = opts; - return request; -} - -// Returns whether the returned trace is empty. -// Failure are handled by CHECK, i.e. abort() -bool Profile(const string& service_addr, const string& logdir, int duration_ms, - const string& repository_root, const string& session_id, - const ProfileOptions& opts) { - ProfileRequest request = - PopulateProfileRequest(duration_ms, repository_root, session_id, opts); - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their - // `ValidateHostPortPair` checks for empty host string case. 
- channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits::max()); - std::unique_ptr stub = - TPUProfiler::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - ProfileResponse response; - TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response))); - - if (!response.encoded_trace().empty()) { - TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile( - logdir, session_id, "", response, &std::cout)); - // Print this at the end so that it's not buried in irrelevant LOG messages. - std::cout - << "NOTE: using the trace duration " << duration_ms << "ms." - << std::endl - << "Set an appropriate duration (with --duration_ms) if you " - "don't see a full step in your trace or the captured trace is too " - "large." - << std::endl; - } - - return response.encoded_trace().empty(); -} - -// Start a new profiling session that include all the hosts included in -// hostnames, for the time interval of duration_ms. Possibly save the profiling -// result in the directory specified by repository_root and session_id. -bool NewSession(const string& service_addr, - const std::vector& hostnames, - int duration_ms, const string& repository_root, - const string& session_id, const ProfileOptions& opts) { - NewProfileSessionRequest new_session_request; - *new_session_request.mutable_request() = - PopulateProfileRequest(duration_ms, repository_root, session_id, opts); - new_session_request.set_repository_root(repository_root); - new_session_request.set_session_id(session_id); - for (const auto& hostname : hostnames) { - new_session_request.add_hosts(hostname); - } - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their - // `ValidateHostPortPair` checks for empty host string case. 
- channel_args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - // TODO(jiesun): GRPC support following relevant naming scheme: - // 1. dns:///host:port - // 2. ipv4:host:port or ipv6:[host]:port - // We might need to change the prefix which depends on what TPU name resolver - // will give us. - std::unique_ptr stub = - TPUProfileAnalysis::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - NewProfileSessionResponse new_session_response; - TF_QCHECK_OK(FromGrpcStatus( - stub->NewSession(&context, new_session_request, &new_session_response))); - - std::cout << "Profile session succeed for host(s):" - << str_util::Join(hostnames, ",") << std::endl; - return new_session_response.empty_trace(); -} - -// Starts tracing on a single or multiple TPU hosts and saves the result in the -// given logdir. If no trace was collected, retries tracing for -// num_tracing_attempts. -void StartTracing(const tensorflow::string& service_addr, - const tensorflow::string& logdir, - const tensorflow::string& workers_list, - bool include_dataset_ops, int duration_ms, - int num_tracing_attempts) { - // Use the current timestamp as the run name. - tensorflow::string session_id = GetCurrentTimeStampAsString(); - constexpr char kProfilePluginDirectory[] = "plugins/profile/"; - tensorflow::string repository_root = - io::JoinPath(logdir, kProfilePluginDirectory); - std::vector hostnames = - tensorflow::str_util::Split(workers_list, ","); - - bool empty_trace = false; - int remaining_attempts = num_tracing_attempts; - tensorflow::ProfileOptions opts; - opts.set_include_dataset_ops(include_dataset_ops); - while (true) { - std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. 
" - << "Remaining attempt(s): " << remaining_attempts-- << std::endl; - if (hostnames.empty()) { - empty_trace = tensorflow::tpu::Profile(service_addr, logdir, duration_ms, - repository_root, session_id, opts); - } else { - tensorflow::string tpu_master = service_addr; - empty_trace = - tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, - repository_root, session_id, opts); - } - if (remaining_attempts <= 0 || !empty_trace) break; - std::cout << "No trace event is collected. Automatically retrying." - << std::endl - << std::endl; - } - - if (empty_trace) { - std::cout << "No trace event is collected after " << num_tracing_attempts - << " attempt(s). " - << "Perhaps, you want to try again (with more attempts?)." - << std::endl - << "Tip: increase number of attempts with --num_tracing_attempts." - << std::endl; - } -} - -MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) { - MonitorRequest request; - request.set_duration_ms(duration_ms); - request.set_monitoring_level(monitoring_level); - return request; -} - -// Repeatedly collects profiles and shows user-friendly metrics for -// 'num_queries' time(s). 
-void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, - int monitoring_level, int num_queries) { - for (int query = 0; query < num_queries; ++query) { - MonitorRequest request = - PopulateMonitorRequest(duration_ms, monitoring_level); - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits::max()); - std::unique_ptr stub = - TPUProfiler::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - MonitorResponse response; - TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response))); - - std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1 - << "):\n\n" - << response.data() << std::flush; - } -} - -} // namespace -} // namespace tpu -} // namespace tensorflow - int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; @@ -300,8 +76,9 @@ int main(int argc, char** argv) { std::cout << usage.c_str() << std::endl; return 2; } - tensorflow::Status status = - tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr); + tensorflow::Status status; + status = + tensorflow::profiler::client::ValidateHostPortPair(FLAGS_service_addr); if (!status.ok()) { std::cout << status.error_message() << std::endl; std::cout << usage.c_str() << std::endl; @@ -324,12 +101,17 @@ int main(int argc, char** argv) { << FLAGS_service_addr << " for " << duration_ms << "ms and show metrics for " << num_queries << " time(s)." 
<< std::endl; - tensorflow::tpu::StartMonitoring(FLAGS_service_addr, duration_ms, - FLAGS_monitoring_level, num_queries); + tensorflow::profiler::client::StartMonitoring( + FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, num_queries); } else { - tensorflow::tpu::StartTracing(FLAGS_service_addr, FLAGS_logdir, - FLAGS_workers_list, FLAGS_include_dataset_ops, - duration_ms, num_tracing_attempts); + status = tensorflow::profiler::client::StartTracing( + FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list, + FLAGS_include_dataset_ops, duration_ms, num_tracing_attempts); + if (!status.ok()) { + std::cout << status.error_message() << std::endl; + std::cout << usage.c_str() << std::endl; + return 2; + } } return 0; } diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index f27ae38e0434991da7475e631be1c6cb4a463118..807cf26fe983b4ebe17695d6f4f90ecfc0e0cbf5 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -33,7 +33,7 @@ setup( long_description='Tools for capture TPU profile', url='https://www.tensorflow.org/tfrc/', author='Google Inc.', - author_email='opensource@google.com', + author_email='packages@tensorflow.org', packages=['cloud_tpu_profiler'], package_data={ 'cloud_tpu_profiler': ['data/*'], diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 6a6eba282a12d68cc3cd4e46a46a1b4190fb737b..8605bae5c128513186d8c03835dcf49d3e4b6fd9 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -1,389 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Operations for TPUs.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import platform - -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging - -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops import gen_tpu_ops - from tensorflow.contrib.tpu.ops.gen_tpu_ops import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - - _tpu_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_tpu_ops.so")) - - def _create_default_group_assignment(): - num_shards = tpu_function.get_tpu_context().number_of_shards - if num_shards is None: - logging.warning( - "cross_replica_sum should be used within a tpu_shard_context, but " - "got unset number_of_shards. 
Assuming 1.") - num_shards = 1 - group_assignment = [list(range(num_shards))] - return group_assignment - - def all_to_all(x, - concat_dimension, - split_dimension, - split_count, - group_assignment=None, - name=None): - """Exchange data across TPU replicas. - - Args: - x: The local tensor. - concat_dimension: The dimension number to concatenate. - split_dimension: The dimension number to split. - split_count: The number of splits, this number must equal to the sub-group - size(group_assignment.get_shape()[1]) - group_assignment: Optional 2d int32 lists with shape [num_groups, - num_replicas_per_group]. `group_assignment[i]` represents the replica - ids in the ith subgroup. - name: Optional op name. - - Returns: - A `Tensor` which is concatenated by data from different replicas. - """ - if group_assignment is None: - group_assignment = _create_default_group_assignment() - return gen_tpu_ops.all_to_all( - x, - group_assignment, - concat_dimension=concat_dimension, - split_dimension=split_dimension, - split_count=split_count, - name=name) - - @ops.RegisterGradient("AllToAll") - def _all_to_all_grad(op, grad): - # The gradient of a all-to-all is also a all-to-all but the - # split_dimension and concat_dimension is swapped. - # The graident with respect to group_assignment is None. - return [ - gen_tpu_ops.all_to_all( - grad, - op.inputs[1], - concat_dimension=op.get_attr("split_dimension"), - split_dimension=op.get_attr("concat_dimension"), - split_count=op.get_attr("split_count")), None - ] - - def cross_replica_sum(x, group_assignment=None, name=None): - """Sum the input tensor across replicas according to group_assignment. - - Args: - x: The local tensor to the sum. - group_assignment: Optional 2d int32 lists with shape [num_groups, - num_replicas_per_group]. `group_assignment[i]` represents the replica - ids in the ith subgroup. - name: Optional op name. - - Returns: - A `Tensor` which is summed across replicas. 
- """ - if group_assignment is None: - group_assignment = _create_default_group_assignment() - - return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name) - - def collective_permute(x, source_target_pairs, name=None): - """Permute the input tensor across replicas given source_target_pairs. - - For each source_target_pair , we send replica a's input to replica b. - Each replica id must only appear once in the source column. Also it must - only appear once in the target column. - For the replica id not in the target column, this op returns a zero tensor - with the same shape and dtype of the input x. - - For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing - source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs: - `[0, A, B, C]`. - - Args: - x: The local tensor to be permuted. - source_target_pairs: 2d int lists with shape [num_pairs, 2]. - source_target_pairs[i][0] represents the source replica id and - source_target_pairs[i][1] represents the target replica id. - name: Optional op name. - - Returns: - A `Tensor` which is permuted. - """ - return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name) - - @ops.RegisterGradient("CollectivePermute") - def _collective_permute_grad(op, grad): - # The gradient of a collective permute operation is also a collective - # permute, but with source/target pairs reversed. The gradient with respect - # to input argument `source_target_pairs` is `None`. - source_target_pairs = op.inputs[1][:, ::-1] - return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None] - - @ops.RegisterGradient("CrossReplicaSum") - def _cross_replica_sum_grad(op, grad): - # The gradient of a cross replica sum is also a cross-replica sum. - # The gradient with respect to group_assignment is None. - return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None] - - # This extra type checking exists to give a more helpful error message in - # the common case that uint8 and int64 values are infed. 
Remove when both - # types are supported. - - _SUPPORTED_INFEED_DTYPES = set([ - dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32, - dtypes.complex64 - ]) - - def infeed_dequeue(dtype, shape, name=None): - """A placeholder op for a value that will be fed into the computation. - - Args: - dtype: A `tf.DType`. The type of elements in the tensor. - shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor. - name: A name for the operation (optional). - - Returns: - A `Tensor` of type `dtype`. - A tensor that will be provided using the infeed mechanism. - - Raises: - TypeError: If 'dtype` is not a supported infeed type. - """ - if dtype not in _SUPPORTED_INFEED_DTYPES: - raise TypeError( - "{} is not a supported TPU infeed type. Supported types are: " - "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) - - return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name) - - # pylint: disable=redefined-outer-name - def infeed_dequeue_tuple(dtypes, shapes, name=None): - """A placeholder op for values fed into the TPU simultaneously as a tuple. - - Args: - dtypes: A list of `tf.DType`s that has length `>= 1`. - The element types of each element in `outputs`. - shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). - The shapes of each tensor in `outputs`. - name: A name for the operation (optional). - - Returns: - A list of `Tensor` objects of type `dtypes`. - A list of tensors that will be provided using the infeed mechanism. - - Raises: - TypeError: If a type in 'dtypes` is not a supported infeed type. - """ - for dtype in dtypes: - if dtype not in _SUPPORTED_INFEED_DTYPES: - raise TypeError( - "{} is not a supported TPU infeed type. 
Supported types are: " - "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) - return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name) - # pylint: enable=redefined-outer-name - - # pylint: disable=protected-access - def send_tpu_embedding_gradients(inputs, - config, - learning_rates=None, - name=None): - """A placeholder op for feeding per-sample gradients to the embedding layer. - - Args: - inputs: A TensorList of gradients with which to update embedding tables. - Contains one tensor per embedding table in the model. - config: Serialized TPUEmbeddingConfiguration proto. - learning_rates: A TensorList of float32 scalars, one for each embedding - table, containing the learning rates for each table when dynamic - learning rate is enabled through the OptimizationParameters in - TPUEmbeddingConfiguration. When the learning rate is constant, the list - should be empty (optional). - name: A name for the operation (optional). - - Returns: - A SendTPUEmbeddingGradients operation. - """ - if learning_rates is None: - learning_rates = [] - return gen_tpu_ops._send_tpu_embedding_gradients( - inputs=inputs, learning_rates=learning_rates, config=config, name=name) - - - send_tpu_embedding_gradients.__doc__ = ( - gen_tpu_ops._send_tpu_embedding_gradients.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_integer_batch(batch, - device_ordinal, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. - - Args: - batch: A list of 1D tensors, one for each embedding table, containing the - indices into the tables. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. 
When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingIntegerBatch operation. - """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_integer_batch( - batch=batch, - device_ordinal=device_ordinal, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_integer_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_integer_batch.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_sparse_batch(sample_indices, - embedding_indices, - aggregation_weights, - device_ordinal, - combiners=None, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. - - Args: - sample_indices: A list of rank 1 Tensors specifying the training example - and feature to which the corresponding embedding_indices and - aggregation_weights values belong. sample_indices[i] must equal b * nf + - f, where nf is the number of features from the corresponding table, f is - in [0, nf), and b is in [0, batch size). - embedding_indices: A list of rank 1 Tensors, indices into the embedding - tables. - aggregation_weights: A list of rank 1 Tensors containing per sample -- - i.e. per (training example, feature) -- aggregation weights. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - combiners: A list of string scalars, one for each embedding table that - specify how to normalize the embedding activations after weighted - summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is - invalid to have the sum of the weights be 0 for 'mean' or the sum of the - squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default - is to use 'sum' for all tables (optional). 
- mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingSparseBatch operation. - """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_batch( - sample_indices=sample_indices, - embedding_indices=embedding_indices, - aggregation_weights=aggregation_weights, - device_ordinal=device_ordinal, - combiners=combiners, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_sparse_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_batch.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, - embedding_indices, - aggregation_weights, - table_ids, - device_ordinal, - combiners=None, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. - - Args: - sample_indices: A list of rank 1 Tensors specifying the training example - to which the corresponding embedding_indices and aggregation_weights - values - belong. It corresponds to sp_ids.indices[:,0] in - embedding_lookup_sparse(). - embedding_indices: A list of rank 1 Tensors, indices into the embedding - tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). - aggregation_weights: A list of rank 1 Tensors containing per training - example aggregation weights. It corresponds to sp_weights.values in - embedding_lookup_sparse(). - table_ids: A list of integers specifying the identifier of the embedding - table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to - lookup the corresponding input. The ith input is looked up using - table_ids[i]. 
The size of the table_ids list must be equal to that of - sample_indices, embedding_indices and aggregation_weights. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - combiners: A list of string scalars, one for each embedding table that - specify how to normalize the embedding activations after weighted - summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is - invalid to have the sum of the weights be 0 for 'mean' or the sum of the - squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default - is to use 'sum' for all tables (optional). - mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingSparseTensorBatch operation. 
- """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch( - sample_indices=sample_indices, - embedding_indices=embedding_indices, - aggregation_weights=aggregation_weights, - table_ids=table_ids, - device_ordinal=device_ordinal, - combiners=combiners, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch.__doc__) - -else: - # We have already built the appropriate libraries into the binary via CMake - # if we have built contrib, so we don't need this - pass +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.ops.tpu_ops import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py new file mode 100644 index 0000000000000000000000000000000000000000..788e1fe0568cf2f406c379e4d928100ea51a37a3 --- /dev/null +++ b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py @@ -0,0 +1,23 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.ops.tpu_ordinal_selector_op import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/profiler/__init__.py b/tensorflow/contrib/tpu/python/profiler/__init__.py index 15ce6aceec299adacd7025f0021cf8b6f6ef765b..aeb061dbe114bc287946b50d08a86778c78c7b38 100644 --- a/tensorflow/contrib/tpu/python/profiler/__init__.py +++ b/tensorflow/contrib/tpu/python/profiler/__init__.py @@ -1,31 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= - -"""Classes for TPU trace events.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.contrib.tpu.profiler.tpu_profiler_analysis_pb2 import * -from tensorflow.contrib.tpu.profiler.trace_events_pb2 import * +from tensorflow.python.tpu.profiler import * # pylint: enable=wildcard-import,unused-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ['Trace', 'Resource', 'Device', 'TraceEvent'] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/tpu/python/tpu/__init__.py b/tensorflow/contrib/tpu/python/tpu/__init__.py index 0dffd7064b19f353aed6afa3ad383564643a4a90..82d4f68c0221013706f70bcf54ae4c97cc7db1d3 100644 --- a/tensorflow/contrib/tpu/python/tpu/__init__.py +++ b/tensorflow/contrib/tpu/python/tpu/__init__.py @@ -1,20 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= - -"""Ops related to Tensor Processing Units.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..41aa4d267812cabe775459723df7e01efaa83c93 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py @@ -0,0 +1,23 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu._tpu_estimator_embedding import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index 1b09ce173a64ba3f93ec019c8fd65dc4710f0fcf..5eb8034e47474873ccef0b6123f2becd0668738c 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -1,212 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the 'License'); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Hook for asynchronous checkpointing. - -This hook dispatches checkpoint writing operations in a separate thread to -allow execution to continue on the main thread. 
-""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import threading -import time - -from tensorflow.core.util.event_pb2 import SessionLog -from tensorflow.python.framework import meta_graph -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import training_util -from tensorflow.python.training.session_run_hook import SessionRunArgs -from tensorflow.python.training.summary_io import SummaryWriterCache - - -class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): - """Saves checkpoints every N steps or seconds.""" - - def __init__(self, - checkpoint_dir, - save_secs=None, - save_steps=None, - saver=None, - checkpoint_basename="model.ckpt", - scaffold=None, - listeners=None): - """Initializes a `CheckpointSaverHook`. - - Args: - checkpoint_dir: `str`, base directory for the checkpoint files. - save_secs: `int`, save every N secs. - save_steps: `int`, save every N steps. - saver: `Saver` object, used for saving. - checkpoint_basename: `str`, base name for the checkpoint files. - scaffold: `Scaffold`, use to get saver object. - listeners: List of `CheckpointSaverListener` subclass instances. Used for - callbacks that run immediately before or after this hook saves the - checkpoint. - - Raises: - ValueError: One of `save_steps` or `save_secs` should be set. - ValueError: At most one of `saver` or `scaffold` should be set. 
- """ - logging.info("Create AsyncCheckpointSaverHook.") - if saver is not None and scaffold is not None: - raise ValueError("You cannot provide both saver and scaffold.") - self._saver = saver - self._save_thread = None - self._write_graph_thread = None - self._checkpoint_dir = checkpoint_dir - self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) - self._scaffold = scaffold - self._timer = basic_session_run_hooks.SecondOrStepTimer( - every_secs=save_secs, every_steps=save_steps) - self._listeners = listeners or [] - self._steps_per_run = 1 - self._summary_writer = None - self._global_step_tensor = None - - self._last_checkpoint_step = None - - def _set_steps_per_run(self, steps_per_run): - self._steps_per_run = steps_per_run - - def begin(self): - self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) - self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access - if self._global_step_tensor is None: - raise RuntimeError( - "Global step should be created to use CheckpointSaverHook.") - for l in self._listeners: - l.begin() - - def after_create_session(self, session, coord): - global_step = session.run(self._global_step_tensor) - - # We do write graph and saver_def at the first call of before_run. - # We cannot do this in begin, since we let other hooks to change graph and - # add variables in begin. Graph is finalized after all begin calls. 
- def _write_graph_fn(self): - training_util.write_graph( - ops.get_default_graph().as_graph_def(add_shapes=True), - self._checkpoint_dir, "graph.pbtxt") - self._write_graph_thread = threading.Thread(target=_write_graph_fn, - args=[self]) - self._write_graph_thread.start() - - saver_def = self._get_saver().saver_def if self._get_saver() else None - graph = ops.get_default_graph() - meta_graph_def = meta_graph.create_meta_graph_def( - graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) - self._summary_writer.add_graph(graph) - self._summary_writer.add_meta_graph(meta_graph_def) - # The checkpoint saved here is the state at step "global_step". - self._save(session, global_step) - self._timer.update_last_triggered_step(global_step) - - def before_run(self, run_context): # pylint: disable=unused-argument - return SessionRunArgs(self._global_step_tensor) - - def after_run(self, run_context, run_values): - global_step = run_context.session.run(self._global_step_tensor) - if self._timer.should_trigger_for_step(global_step): - self._timer.update_last_triggered_step(global_step) - logging.info("Triggering checkpoint. 
%s", global_step) - if self._save(run_context.session, global_step): - run_context.request_stop() - - def end(self, session): - if self._save_thread: - logging.info("Waiting for any pending checkpoints to finish.") - self._save_thread.join() - if self._write_graph_thread: - logging.info("Waiting for any pending write_graph to finish.") - self._write_graph_thread.join() - - last_step = session.run(self._global_step_tensor) - - if self._last_checkpoint_step != last_step: - self._save(session, last_step, asynchronous=False) - - for l in self._listeners: - l.end(session, last_step) - - def _save(self, session, step, asynchronous=True): - """Saves the latest checkpoint, returns should_stop.""" - - # Skip saving on step 0 - if step == 0: - return - - def _save_fn(): - """Run the saver process.""" - logging.info("Saving checkpoints for %d into %s.", step, self._save_path) - - start_time = time.time() - for l in self._listeners: - l.before_save(session, step) - - self._get_saver().save(session, self._save_path, global_step=step) - self._summary_writer.add_session_log( - SessionLog( - status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), - step) - - for l in self._listeners: - l.after_save(session, step) - - end_time = time.time() - logging.info("Checkpoint actual writing time: (%.3f sec)", - end_time - start_time) - logging.info("Checkpoint finished for %d into %s.", step, self._save_path) - - if not asynchronous: - self._last_checkpoint_step = step - _save_fn() - return - - if self._save_thread is not None: - self._save_thread.join(timeout=0.1) - if self._save_thread.is_alive(): - logging.info("Saver thread still in progress, skipping checkpoint.") - return - - self._last_checkpoint_step = step - self._save_thread = threading.Thread(target=_save_fn) - self._save_thread.start() - - def _get_saver(self): - if self._saver is not None: - return self._saver - elif self._scaffold is not None: - return self._scaffold.saver - - # Get saver from the SAVERS collection if 
present. - collection_key = ops.GraphKeys.SAVERS - savers = ops.get_collection(collection_key) - if not savers: - raise RuntimeError( - "No items in collection {}. Please add a saver to the collection " - "or provide a saver or scaffold.".format(collection_key)) - elif len(savers) > 1: - raise RuntimeError( - "More than one item in collection {}. " - "Please indicate which one to use by passing it to the constructor." - .format(collection_key)) - - self._saver = savers[0] - return savers[0] +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.async_checkpoint import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/bfloat16.py b/tensorflow/contrib/tpu/python/tpu/bfloat16.py index fa74f651aa63c72d14eb78c8af479263810e9b7d..f3d392a8daec2a80f974d90051324a02be002afd 100644 --- a/tensorflow/contrib/tpu/python/tpu/bfloat16.py +++ b/tensorflow/contrib/tpu/python/tpu/bfloat16.py @@ -1,77 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= - -"""Helper context for running models with bfloat16.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import tf_contextlib - - -def _get_custom_getter(): - """Returns a custom getter that this class's methods must be called under. - - All methods of this class must be called under a variable scope that was - passed this custom getter. Example: - - ```python - network = ConvNetBuilder(...) - with tf.variable_scope('cg', custom_getter=network.get_custom_getter()): - network.conv(...) - # Call more methods of network here - ``` - - Currently, this custom getter only does anything if self.use_tf_layers is - True. In that case, it causes variables to be stored as dtype - self.variable_type, then casted to the requested dtype, instead of directly - storing the variable as the requested dtype. - """ - - def inner_custom_getter(getter, *args, **kwargs): - """Custom getter that forces variables to have type self.variable_type.""" - cast_to_bfloat16 = False - requested_dtype = kwargs['dtype'] - if requested_dtype == dtypes.bfloat16: - # Only change the variable dtype if doing so does not decrease variable - # precision. - kwargs['dtype'] = dtypes.float32 - cast_to_bfloat16 = True - var = getter(*args, **kwargs) - # This if statement is needed to guard the cast, because batch norm - # assigns directly to the return value of this custom getter. The cast - # makes the return value not a variable so it cannot be assigned. Batch - # norm variables are always in fp32 so this if statement is never - # triggered for them. 
- if cast_to_bfloat16: - var = math_ops.cast(var, dtypes.bfloat16) - return var - - return inner_custom_getter - - -@tf_contextlib.contextmanager -def bfloat16_scope(): - """Scope class for bfloat16 variables so that the model uses custom getter. - - This enables variables to be read as bfloat16 type when using get_variable. - """ - with variable_scope.variable_scope( - '', custom_getter=_get_custom_getter()) as varscope: - yield varscope +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.bfloat16 import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index 8d6245390fc3fa005c92d01bc9b64ddb47583582..c20aac7e36aa31c5a9d88ca6fe02a8703f9ed5a3 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -1,194 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ====================================== -"""Library of Cloud TPU helper functions for data loading.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.experimental.ops import batching -from tensorflow.python.data.experimental.ops import interleave_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.ops import readers -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import functional_ops - - -def _TextLineDataset(filename): - buffer_size = 8 * 1024 * 1024 # 8 MiB per file - dataset = readers.TextLineDataset(filename, buffer_size=buffer_size) - return dataset - - -def _TFRecordDataset(filename): - buffer_size = 8 * 1024 * 1024 # 8 MiB per file - dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size) - return dataset - - -_FILETYPE_MAP = { - 'tfrecord': _TFRecordDataset, - 'textline': _TextLineDataset, - 'text': _TextLineDataset, -} - - -def StreamingFilesDataset(files, - filetype=None, - file_reader_job=None, - worker_job=None, - num_epochs=None, - filename_shuffle_buffer_size=None, - num_parallel_reads=None, - batch_transfer_size=None, - sloppy=None): - """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM). - - Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read - files local to your GCE VM. In order to train using files stored on your local - VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset - helper to generate a dataset to feed your Cloud TPU with files from your GCE - VM. 
- - The resulting dataset may return an OutOfRangeError if there are no files - found as a result of the fileglob expansion. - - Note: StreamingFilesDataset assumes that the session is using a - TPUClusterResolver and has therefore a worker and a coordinator job. File - loading will be done on the coordinator job. - - Args: - files: A string glob to match files, or a `tf.data.Dataset` generating file - names. - filetype: A string (one of 'tfrecord', or 'textline') or a single-argument - TensorFlow function that when given a filename returns a dataset. - file_reader_job: An optional string that corresponds to the job that should - perform the file reads. - worker_job: An optional string that corresponds to the job that should - process the tensors (i.e. your GPU or TPU worker). - num_epochs: The number of epochs through the training set that should be - generated. By default, it will repeat infinitely. - filename_shuffle_buffer_size: An optional integer whose value controls the - shuffling of the file names. If you would like to read from the files in - the same order, set to 0 or False. - num_parallel_reads: An optional integer controlling the number of files to - read from concurrently. (Set to 1 for no parallelism.) - batch_transfer_size: An optional integer controlling the batching used to - amortize the remote function invocation overhead. Set to a very large - number to increase throughput. Set to a very small number to reduce memory - consumption. Set to False to skip batching. - sloppy: (Optional.) If `False`, read input data while maintaining a - deterministic order. (This may have significant performance impacts.) - sloppy defaults to: True. - Returns: - A `tf.data.Dataset` with an infinite stream of elements generated by a - parallel interleaving of the set of files matched (or generated) by `files` - with a type is the output of the dataset specified by `filetype`. - - Raises: - ValueError: if any argument is not of the expected type. 
- """ - if filetype is None: - filetype = 'tfrecord' - - if isinstance(filetype, str): - if filetype not in _FILETYPE_MAP: - raise ValueError('Unexpected filetype: %s' % filetype) - reader_fn = _FILETYPE_MAP[filetype] - elif callable(filetype): - reader_fn = filetype - else: - raise ValueError('filetype should be a string or a callable') - - file_reader_job = file_reader_job or 'coordinator' - - worker_job = worker_job or 'worker' - - if filename_shuffle_buffer_size is None: - filename_shuffle_buffer_size = 4096 - - num_parallel_reads = num_parallel_reads or 8 - - if batch_transfer_size is None: - batch_transfer_size = 256 - - if sloppy is None: - sloppy = True - - with ops.device('/job:%s' % file_reader_job): - if isinstance(files, str): - source_dataset = dataset_ops.Dataset.list_files(files) - elif isinstance(files, dataset_ops.DatasetV2): - source_dataset = files - else: - raise ValueError('files was not a string or a dataset: %s' % files) - - if filename_shuffle_buffer_size: - source_dataset = source_dataset.shuffle( - buffer_size=filename_shuffle_buffer_size) - - # NOTE: We perform the `repeat` on the source dataset, because the output - # dataset does not currently have enough information to recreate an iterator - # over the source dataset when it reaches the end. 
- source_dataset = source_dataset.repeat(num_epochs) - - source_dataset = source_dataset.apply( - interleave_ops.parallel_interleave( - reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy)) - - if batch_transfer_size: - source_dataset = source_dataset.batch(batch_transfer_size) - - source_dataset = source_dataset.prefetch(1) - - source_iterator = dataset_ops.make_one_shot_iterator(source_dataset) - source_handle = source_iterator.string_handle() - - @function.Defun(dtypes.string) - def LoadingFunc(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, source_dataset.output_types, source_dataset.output_shapes) - return remote_iterator.get_next() - - def MapFn(unused_input): - if isinstance(source_dataset.output_types, dtypes.DType): - output_types = [source_dataset.output_types] - elif isinstance(source_dataset.output_types, (list, tuple)): - output_types = source_dataset.output_types - else: - raise ValueError('source dataset has invalid output types') - remote_calls = functional_ops.remote_call( - args=[source_handle], - Tout=output_types, - f=LoadingFunc, - target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job) - if len(remote_calls) == 1: - return remote_calls[0] - else: - return remote_calls - - with ops.device('/job:%s' % worker_job): - output_dataset = dataset_ops.Dataset.range(2).repeat().map( - MapFn, num_parallel_calls=4 if sloppy else None) - output_dataset = output_dataset.prefetch(1) - - if batch_transfer_size: - # Undo the batching used during the transfer. 
- output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1) - - return output_dataset +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.datasets import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index 6906501ecf90c8e577aa0becf2dba818deb19df4..05dffef3a1efdae2ad7306ca5ad3bc7a9eac04cf 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -1,310 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ====================================== -"""Library of TPU helper functions.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.python.tpu.topology import Topology - - -def _compute_task_and_cores_to_replicas(core_assignment, topology): - """Computes a nested dict which maps task and logical core to replicas.""" - task_and_cores_to_replicas = {} - for replica in xrange(core_assignment.shape[0]): - for logical_core in xrange(core_assignment.shape[1]): - coordinates = core_assignment[replica, logical_core, :] - task_id = topology.task_ordinal_at_coordinates(coordinates) - if task_id not in task_and_cores_to_replicas: - task_and_cores_to_replicas[task_id] = {} - if logical_core not in task_and_cores_to_replicas[task_id]: - task_and_cores_to_replicas[task_id][logical_core] = set() - - task_and_cores_to_replicas[task_id][logical_core].add(replica) - - task_to_sorted_replica_id = {} - - for task, core_to_replicas in task_and_cores_to_replicas.items(): - core_to_sorted_replicas = {} - for core, replicas in core_to_replicas.items(): - core_to_sorted_replicas[core] = sorted(replicas) - - task_to_sorted_replica_id[task] = core_to_sorted_replicas - return task_to_sorted_replica_id - - -class DeviceAssignment(object): - """Mapping from logical cores in a computation to the physical TPU topology. - - Prefer to use the `device_assignment()` helper to construct a - `DeviceAssignment`; it is easier if less flexible than constructing a - `DeviceAssignment` directly. - """ - - def __init__(self, topology, core_assignment): - """Constructs a `DeviceAssignment` object. - - Args: - topology: A `Topology` object that describes the physical TPU topology. 
- core_assignment: A logical to physical core mapping, represented as a - rank 3 numpy array. See the description of the `core_assignment` - property for more details. - - Raises: - ValueError: If `topology` is not `Topology` object. - ValueError: If `core_assignment` is not a rank 3 numpy array. - """ - if not isinstance(topology, Topology): - raise ValueError("topology must be a Topology object, got {}".format( - type(topology))) - core_assignment = np.asarray(core_assignment, dtype=np.int32) - - self._topology = topology - - if core_assignment.ndim != 3: - raise ValueError("core_assignment must be a rank 3 numpy array, " - "got shape {}".format(core_assignment.shape)) - - self._num_replicas = core_assignment.shape[0] - self._num_cores_per_replica = core_assignment.shape[1] - - if core_assignment.shape[-1] != topology.mesh_rank: - raise ValueError( - "minor dimension of core_assignment must have size equal to topology " - "rank ({}), got shape {}".format(topology.mesh_rank, - core_assignment.shape)) - - self._core_assignment = core_assignment - self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas( - self._core_assignment, topology) - - @property - def topology(self): - """A `Topology` that describes the TPU topology.""" - return self._topology - - @property - def num_cores_per_replica(self): - """The number of cores per replica.""" - return self._num_cores_per_replica - - @property - def num_replicas(self): - """The number of replicas of the computation.""" - return self._num_replicas - - @property - def core_assignment(self): - """The logical to physical core mapping. - - Returns: - An integer numpy array of rank 3, with shape - `[num_replicas, num_cores_per_replica, topology_rank]`. Maps - (replica, logical core) pairs to physical topology coordinates. 
- """ - return self._core_assignment - - def _coordinates(self, replica, logical_core): - """Returns the physical topology coordinates of a logical core.""" - return tuple(self.core_assignment[replica, logical_core, :]) - - def lookup_replicas(self, task_id, logical_core): - """Lookup replica ids by task number and logical core. - - Args: - task_id: TensorFlow task number. - logical_core: An integer, identifying a logical core. - Returns: - A sorted list of the replicas that are attached to that task and - logical_core. - Raises: - ValueError: If no replica exists in the task which contains the logical - core. - """ - try: - return self._task_and_cores_to_replicas[task_id][logical_core] - except KeyError: - raise ValueError( - "Can not find any replica in task: {} contains logical_core: {} ". - format(task_id, logical_core)) - - def tpu_ordinal(self, replica=0, logical_core=0): - """Returns the ordinal of the TPU device assigned to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.tpu_device_ordinal_at_coordinates(coordinates) - - def host_device(self, replica=0, logical_core=0, job=None): - """Returns the CPU device attached to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.cpu_device_name_at_coordinates(coordinates, job=job) - - def tpu_device(self, replica=0, logical_core=0, job=None): - """Returns the name of the TPU device assigned to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.tpu_device_name_at_coordinates(coordinates, job=job) - - -def device_assignment(topology, - computation_shape=None, - computation_stride=None, - num_replicas=1): - """Computes a device_assignment of a computation across a TPU topology. - - Attempts to choose a compact grid of cores for locality. - - Returns a `DeviceAssignment` that describes the cores in the topology assigned - to each core of each replica. 
- - `computation_shape` and `computation_stride` values should be powers of 2 for - optimal packing. - - Args: - topology: A `Topology` object that describes the TPU cluster topology. - To obtain a TPU topology, evaluate the `Tensor` returned by - `initialize_system` using `Session.run`. Either a serialized - `TopologyProto` or a `Topology` object may be passed. Note: you must - evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here. - computation_shape: A rank 1 int32 numpy array with size equal to the - topology rank, describing the shape of the computation's block of cores. - If None, the `computation_shape` is `[1] * topology_rank`. - computation_stride: A rank 1 int32 numpy array of size `topology_rank`, - describing the inter-core spacing of the `computation_shape` cores in the - TPU topology. If None, the `computation_stride` is `[1] * topology_rank`. - num_replicas: The number of computation replicas to run. The replicas will - be packed into the free spaces of the topology. - - Returns: - A DeviceAssignment object, which describes the mapping between the logical - cores in each computation replica and the physical cores in the TPU - topology. - - Raises: - ValueError: If `topology` is not a valid `Topology` object. - ValueError: If `computation_shape` or `computation_stride` are not 1D int32 - numpy arrays with shape [3] where all values are positive. - ValueError: If computation's replicas cannot fit into the TPU topology. - """ - # Deserialize the Topology proto, if it is a string. 
- if isinstance(topology, bytes): - topology = Topology(serialized=topology) - - if not isinstance(topology, Topology): - raise ValueError("`topology` is not a Topology object; got {}".format( - type(topology))) - - topology_rank = len(topology.mesh_shape) - mesh_shape = topology.mesh_shape - if computation_shape is None: - computation_shape = np.array([1] * topology_rank, dtype=np.int32) - else: - computation_shape = np.asarray(computation_shape, dtype=np.int32) - - if computation_stride is None: - computation_stride = np.array([1] * topology_rank, dtype=np.int32) - else: - computation_stride = np.asarray(computation_stride, dtype=np.int32) - - if computation_shape.shape != (topology_rank,): - raise ValueError("computation_shape must have shape [{}]; got {}".format( - topology_rank, computation_shape.shape)) - if computation_stride.shape != (topology_rank,): - raise ValueError("computation_stride must have shape [{}]; got {}".format( - topology_rank, computation_stride.shape)) - - if any(computation_shape < 1): - raise ValueError( - "computation_shape must be positive; got computation_shape={}".format( - computation_shape)) - if any(computation_stride < 1): - raise ValueError( - "computation_stride must be positive; got computation_stride={}".format( - computation_stride)) - - # Computes the physical size of one computation instance. - computation_footprint = computation_shape * computation_stride - if any(computation_footprint > mesh_shape): - raise ValueError( - "computation footprint {} does not fit in TPU topology shape {}".format( - computation_footprint, mesh_shape)) - - # Computes how many copies of the computation footprint fit in the mesh. 
- block_counts = mesh_shape // computation_footprint - - replica_counts = block_counts * computation_stride - max_replicas = np.prod(replica_counts) - if num_replicas > max_replicas: - raise ValueError( - "requested {} replicas but only {} replicas with shape {} and " - "computation_stride {} fit in a TPU mesh of shape {}".format( - num_replicas, max_replicas, computation_shape, computation_stride, - mesh_shape)) - - def ceil_of_ratio(n, m): - return (n + m - 1) // m - - replica_shape = [0] * topology_rank - if num_replicas > 0: - remaining_replicas = num_replicas - remaining_dims = topology_rank - - # Choose dimensions as close to an equal cube as possible, in order of - # increasing dimension size. By visiting dimensions in increasing size, we - # assign the most constrained dimension first, so we won't make infeasible - # choices. - # - # As a secondary sort order, visit the dimensions in reverse order. This - # means we try to use both cores on the same chip in preference to two cores - # on different chips. - for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))): - i = -ni - target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims))) - replica_shape[i] = min(target_size, x) - remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i]) - remaining_dims -= 1 - - assert remaining_replicas == 1 and remaining_dims == 0 - - # Assigns an offset to each replica such that no two replicas overlap. - replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32) - for replica in xrange(num_replicas): - # Chooses a replica number in each axis. - t = replica - pos = [] - for dim in replica_shape[::-1]: - pos.append(t % dim) - t //= dim - replica_pos = np.array(pos[::-1], dtype=np.int32) - - # Determines where that replica starts in each axis. 
- outer = replica_pos // computation_stride - inner = replica_pos % computation_stride - replica_offsets[replica, :] = outer * computation_footprint + inner - - # Computes a complete logical core -> physical core mapping for each replica. - indices = [ - np.arange(0, computation_shape[i] * computation_stride[i], - computation_stride[i]) for i in xrange(topology_rank) - ] - indices = np.concatenate( - [i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")], - axis=-1) - indices = indices.reshape((-1, topology_rank)) - assignment = indices + replica_offsets[:, np.newaxis, :] - return DeviceAssignment(topology, core_assignment=assignment) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.device_assignment import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/error_handling.py b/tensorflow/contrib/tpu/python/tpu/error_handling.py index 52e1ea42370d653d1de7c12eee4b456ec7ce921c..1b1328b4075d9a737e40693c13e33e0b7c1fbedf 100644 --- a/tensorflow/contrib/tpu/python/tpu/error_handling.py +++ b/tensorflow/contrib/tpu/python/tpu/error_handling.py @@ -1,132 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# =================================================================== -"""ErrorRendezvous handler for collecting errors from multiple threads.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import sys -import threading -import time - -import six - -from tensorflow.python.framework import errors -from tensorflow.python.platform import tf_logging as logging - -_UNINTERESTING_ERRORS = (errors.CancelledError,) - - -class ErrorRendezvous(object): - """Resolve errors from multiple threads during TPU execution. - - TPU errors can occur on the infeed or outfeed threads as well as the main - training thread. - - Depending on which thread "wins" and receives the session error first, we may - end up showing users a confusing and non-actionable error message (session - cancelled) instead of a root cause (e.g. a bad filename). - - The rendezvous object provides a location to capture these errors until all - threads terminate. At that point we can choose the most informative error - to report. - """ - - def __init__(self, num_sources): - # string -> (message, traceback) - self._errors = {} - self._num_sources = num_sources - self._session_cancel_timer = None - - def record_error(self, source, exc_info, session=None): - """Report an exception from the given source. - - If a session is passed, a timer will be registered to close it after a few - seconds. This is necessary to ensure the main training loop does not hang - if an infeed/oufeed error occurs. We sleep a few seconds to allow a more - interesting error from another thread to propagate. - - Args: - source: string, source of the error - exc_info: Output from `sys.exc_info` (type, value, traceback) - session: Session to close after delay. 
- """ - _, value, _ = exc_info - self._errors[source] = exc_info - logging.info('Error recorded from %s: %s', source, value) - - if session is not None and self._session_cancel_timer is None: - - def _cancel_session(): - time.sleep(5) - try: - session.close() - except: # pylint: disable=bare-except - pass - - self._session_cancel_timer = threading.Thread(target=_cancel_session,) - self._session_cancel_timer.daemon = True - self._session_cancel_timer.start() - - def record_done(self, source): - """Mark execution source `source` as done. - - If an error was originally reported from `source` it is left intact. - - Args: - source: `str`, source being recorded - """ - logging.info('%s marked as finished', source) - if source not in self._errors: - self._errors[source] = None - - @contextlib.contextmanager - def catch_errors(self, source, session=None): - """Context manager to report any errors within a block.""" - try: - yield - except Exception: # pylint: disable=broad-except - self.record_error(source, sys.exc_info(), session) - - def raise_errors(self, timeout_sec=0): - """Wait for up to `timeout` seconds for all error sources to finish. - - Preferentially raise "interesting" errors (errors not in the - _UNINTERESTING_ERRORS) set. - - Args: - timeout_sec: Seconds to wait for other error sources. - """ - for _ in range(timeout_sec): - if len(self._errors) == self._num_sources: - break - time.sleep(1) - - kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None] - - # First check for any interesting errors, then fall back on the session - # cancelled errors etc. 
- for k, (typ, value, traceback) in kept_errors: - if isinstance(value, _UNINTERESTING_ERRORS): - continue - else: - logging.warn('Reraising captured error') - six.reraise(typ, value, traceback) - - for k, (typ, value, traceback) in kept_errors: - logging.warn('Reraising captured error') - six.reraise(typ, value, traceback) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.error_handling import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column.py b/tensorflow/contrib/tpu/python/tpu/feature_column.py new file mode 100644 index 0000000000000000000000000000000000000000..ded75e975b10c4265370af260bf804687c9caebc --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/feature_column.py @@ -0,0 +1,30 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.feature_column import * +# used by tests +from tensorflow.python.tpu.feature_column import _is_running_on_cpu +from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name +from tensorflow.python.tpu.feature_column import _TPU_FC_TO_SCOPE +from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn +from tensorflow.python.tpu.feature_column import _TPUEmbeddingColumn +from tensorflow.python.tpu.feature_column import _TPUSharedEmbeddingColumn +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/functional.py b/tensorflow/contrib/tpu/python/tpu/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5759221ed9660200cc213df69961db56f8d490 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/functional.py @@ -0,0 +1,23 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.functional import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 4ce194590342555a7c4e9e119bf51e516a37a715..6ad4e45e9625f191bb4c01f70b434dc2c4fba638 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -55,8 +55,6 @@ import numpy as np import six from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables from tensorflow.contrib.tpu.python.tpu import tpu @@ -64,6 +62,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import tpu_optimizer from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops @@ -94,6 +93,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from 
tensorflow.python.util.deprecation import deprecated # TODO(b/114775106): temporary shim to optionally initialize the TPU @@ -1373,6 +1373,10 @@ class KerasTPUModel(models.Model): # not hashable. self._numpy_to_infeed_manager_list = [] + # Add distribution specific arguments since we don't call the Model init. + self._distribution_strategy = None + self._compile_distribution = None + self.predict_function = None self.test_function = None self.train_function = None @@ -2069,6 +2073,8 @@ class KerasTPUModel(models.Model): # tpu_model may not be compiled, e.g., loading weights and then predict. return for k, v in six.iteritems(cpu_optimizer_config): + if k == 'name': + continue opt_var = getattr(self._tpu_model.optimizer, k) if isinstance(opt_var, variables.Variable): logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var)) @@ -2097,6 +2103,8 @@ class KerasTPUModel(models.Model): self._cpu_model.set_weights(tpu_weights) for k, v in six.iteritems(tpu_optimizer_config): logging.info('TPU -> CPU %s: %s', k, v) + if k == 'name': + continue opt_var = getattr(self.cpu_optimizer, k) if isinstance(opt_var, variables.Variable): K.get_session().run(opt_var.assign(v)) @@ -2164,7 +2172,10 @@ Output shape: %(output_shape)s # pylint: enable=bad-continuation -@experimental +@deprecated( + '2019-02-20', 'Switch to tf.contrib.distribute.TPUStrategy. ' + 'https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy' +) def tpu_model(model, strategy=None): """Copy `model` along with weights to the TPU. 
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 8b0b240dc7302c203a22349d583323327fc4480b..de425626c813784ef657d17eac0c7bb77599a155 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -69,6 +69,7 @@ class ReplicatedVariable(object): def __init__(self, name, variables): self._name = name self._primary_var = variables[0] + self._common_name = self._primary_var.name.split(":")[0] self._vars = variables self._cached_value = None self._dtype = variables[0].dtype diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index 3e463823c820a3ef8628324f77e1a9caf8d385d5..ed8f9525c9b91208d39805654b01837abdbf3a77 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -1,433 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the 'License'); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ====================================== -"""Operations for handling session logging and shutdown notifications.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - -import time -from google.protobuf import text_format - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.util import event_pb2 -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training_util - -_WATCHDOG = None - - -class CoordinatorShutdownException(Exception): - """Raised when the coordinator needs to shutdown.""" - pass - - -def _clone_session(session, graph=None): - return session_lib.Session( - target=session.sess_str, - config=session._config, # pylint: disable=protected-access - graph=graph if graph else session.graph) - - -def _make_heartbeat_op(session, device, request_ph): - """Return a heartbeat op or None if heartbeats are not supported by device.""" - try: - # Test if we can connect in a isolated graph + session - with ops.Graph().as_default(): - with _clone_session(session) as temp_session: - with ops.device(device): - heartbeat_op = tpu_ops.worker_heartbeat('') - options = config_pb2.RunOptions(timeout_in_ms=5000) - temp_session.run(heartbeat_op, options=options) - except errors.InvalidArgumentError as _: - logging.warning('Error running heartbeat on %s', device) - return None - except errors.DeadlineExceededError as _: - logging.warning('Timeout 
connecting to %s when testing heartbeat', device) - return None - - # If we successfully connected and pinged the worker, go ahead and construct - # the operation. - with ops.device(device): - return tpu_ops.worker_heartbeat(request_ph) - - -class WorkerHeartbeatManager(object): - """Manages the status/heartbeat monitor for a set of workers.""" - - def __init__(self, session, devices, heartbeat_ops, request_placeholder): - """Construct a new WorkerHeartbeatManager. - - (Prefer using `WorkerHeartbeatManager.from_devices` when possible.) - - Args: - session: `tf.Session`, session to use for heartbeat operations. - devices: `list[string]` Set of devices to connect to. - heartbeat_ops: `list[tf.Operation]` Heartbeat operations. - request_placeholder: `tf.Placeholder[String]` Placeholder used to specify - the WorkerHeartbeatRequest protocol buffer. - """ - self._session = session - self._devices = devices - self._ops = heartbeat_ops - self._request_placeholder = request_placeholder - - @staticmethod - def from_devices(session, devices): - """Construct a heartbeat manager for the given devices.""" - if not devices: - logging.error('Trying to create heartbeat manager with no devices?') - - logging.info('Creating heartbeat manager for %s', devices) - request_placeholder = array_ops.placeholder( - name='worker_heartbeat_request', dtype=dtypes.string) - - heartbeat_ops = [] - kept_devices = [] - for device in devices: - heartbeat_op = _make_heartbeat_op(session, device, request_placeholder) - if heartbeat_op is not None: - kept_devices.append(device) - heartbeat_ops.append(heartbeat_op) - else: - logging.warning('Heartbeat support not available for %s', device) - - return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops, - request_placeholder) - - def num_workers(self): - return len(self._devices) - - def configure(self, message): - """Configure heartbeat manager for all devices. 
- - Args: - message: `event_pb2.WorkerHeartbeatRequest` - Returns: `None` - """ - logging.info('Configuring worker heartbeat: %s', - text_format.MessageToString(message)) - self._session.run(self._ops, - {self._request_placeholder: message.SerializeToString()}) - - def ping(self, request=None, timeout_in_ms=5000): - """Ping all workers, returning the parsed status results.""" - if request is None: - request = event_pb2.WorkerHeartbeatRequest() - - options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms) - results = self._session.run( - self._ops, - feed_dict={self._request_placeholder: request.SerializeToString()}, - options=options) - parsed_results = [ - event_pb2.WorkerHeartbeatResponse.FromString(res_pb) - for res_pb in results - ] - logging.debug('Ping results: %s', parsed_results) - return parsed_results - - def lame_workers(self): - """Ping all workers, returning manager containing lame workers (or None).""" - ping_results = self.ping() - lame_workers = [] - - for ping_response, device, op in zip(ping_results, self._devices, - self._ops): - if ping_response.health_status != event_pb2.OK: - lame_workers.append((device, op)) - - if not lame_workers: - return None - - bad_devices, bad_ops = zip(*lame_workers) - return WorkerHeartbeatManager(self._session, bad_devices, bad_ops, - self._request_placeholder) - - def __repr__(self): - return 'HeartbeatManager(%s)' % ','.join(self._devices) - - def shutdown(self, timeout_ms=10000): - """Shutdown all workers after `shutdown_timeout_secs`.""" - logging.info('Shutting down %s.', self) - req = event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms)) - self.configure(req) - - # Wait for workers to shutdown. This isn't strictly required - # but it avoids triggering multiple checkpoints with the same lame worker. 
- logging.info('Waiting %dms for worker shutdown.', timeout_ms) - time.sleep(timeout_ms / 1000) - - -def all_worker_devices(session): - """Return a list of devices for each worker in the system.""" - devices = session.list_devices() - return [ - device.name for device in devices - if ':CPU:' in device.name and 'coordinator' not in device.name - ] - - -class WatchdogManager(threading.Thread): - """Configures worker watchdog timer and handles periodic pings. - - Usage: - # Ping workers every minute, shutting down workers if they haven't received - # a ping after 1 hour. - watchdog_manager = WatchdogManager( - ping_interval=60, shutdown_timeout=3600 - ) - - # Use as a context manager, resetting watchdog on context exit: - with watchdog_manager: - session.run(...) - - # Or setup globally; watchdog will remain active until program exit. - watchdog_manager.configure_and_run() - """ - - def __init__(self, - session, - devices=None, - ping_interval=60, - shutdown_timeout=3600): - """Initialize a watchdog manager. - - Args: - session: Session connected to worker devices. A cloned session and graph - will be created for managing worker pings. - devices: Set of devices to monitor. If none, all workers will be - monitored. - ping_interval: Time, in seconds, between watchdog pings. - shutdown_timeout: Time, in seconds, before watchdog timeout. 
- """ - threading.Thread.__init__(self) - self.ping_interval = ping_interval - self.shutdown_timeout = shutdown_timeout - self.daemon = True - self._config = session._config # pylint: disable=protected-access - self._target = session.sess_str - self._running = False - self._devices = devices - - self._graph = None - self._session = None - self._worker_manager = None - - def _reset_manager(self): - """Reset the graph, session and worker manager.""" - self._graph = ops.Graph() - self._session = session_lib.Session( - target=self._target, - graph=self._graph, - config=self._config, - ) - - if self._devices is None: - self._devices = all_worker_devices(self._session) - - with self._graph.as_default(): - self._worker_manager = WorkerHeartbeatManager.from_devices( - self._session, self._devices) - - self._worker_manager.configure( - event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig( - timeout_ms=self.shutdown_timeout * 1000,))) - - def configure_and_run(self): - logging.info('Enabling watchdog timer with %d second timeout ' - 'and %d second ping interval.', - self.shutdown_timeout, self.ping_interval) - self._reset_manager() - self._running = True - self.start() - - def stop(self): - logging.info('Stopping worker watchdog.') - self._worker_manager.configure( - event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,))) - self._running = False - self.join() - - def __enter__(self): - self.configure_and_run() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - - def run(self): - # Don't fetch logs or adjust timing: just ping the watchdog. - # - # If we hit an exception, reset our session as it is likely broken. 
- while self._running: - try: - self._worker_manager.ping(request=None) - time.sleep(self.ping_interval) - except errors.OpError as e: - # Catch any TF errors that occur so we don't stop sending heartbeats - logging.debug('Caught error while sending heartbeat: %s', e) - self._reset_manager() - - -def start_worker_watchdog(session, - devices=None, - ping_interval=60, - shutdown_timeout=3600): - """Start global worker watchdog to shutdown workers on coordinator exit.""" - global _WATCHDOG - if _WATCHDOG is None: - # Ensure we can send a few pings before we timeout! - ping_interval = min(shutdown_timeout / 10., ping_interval) - _WATCHDOG = WatchdogManager(session, devices, ping_interval, - shutdown_timeout) - _WATCHDOG.configure_and_run() - - -class GracefulShutdownHook(session_run_hook.SessionRunHook): - """Session hook that watches for shutdown events. - - If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a - SystemShutdown exception is raised to terminate the main session. If `saver` - is None the `SAVERS` collection will be read to find a saver. - - `on_shutdown_hooks` is an optional list of functions that should be called - after checkpointing. The function is called with (`run_context`, - `all_workers`, `lame_workers`). - - If `heartbeat_group` is not specified, it will default to all CPU workers - in the system. - """ - - def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None): - self._saver = saver - self._checkpoint_prefix = checkpoint_prefix - self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else [] - - # Worker heartbeats are managed independently of the main training graph. - self._graph = ops.Graph() - self._workers = None - self._session = None - self._heartbeat_supported = False - - def after_create_session(self, training_session, coord): # pylint: disable=unused-argument - # N.B. 
We have to pull the global step here to avoid it being unavailable - # at checkpoint time; the graph has been frozen at that point. - if training_util.get_global_step() is None and self.saver() is not None: - raise ValueError( - 'Saver defined but no global step. Run `get_or_create_global_step()`' - ' in your model definition to allow checkpointing.') - - with self._graph.as_default(): - logging.info('Installing graceful shutdown hook.') - self._session = _clone_session(training_session, self._graph) - self._workers = WorkerHeartbeatManager.from_devices( - self._session, all_worker_devices(self._session)) - self._heartbeat_supported = self._workers.num_workers() > 0 - if self._heartbeat_supported: - self._workers.configure( - event_pb2.WorkerHeartbeatRequest( - shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) - else: - logging.warn( - 'No workers support hearbeats. Failure handling will be disabled.') - - def saver(self): - if self._saver: - return self._saver - - savers = ops.get_collection(ops.GraphKeys.SAVERS) - if not savers: - return None - - if not isinstance(savers, list): - return savers - - if len(savers) > 1: - logging.error( - 'Multiple savers in the SAVERS collection. On-demand checkpointing ' - 'will be disabled. 
Pass an explicit `saver` to the constructor to ' - 'override this behavior.') - return None - - return savers[0] - - def after_run(self, run_context, run_values): - del run_values - - if not self._heartbeat_supported: - return - - lame_workers = self._workers.lame_workers() - if lame_workers: - logging.info('ShutdownHook: lame workers found: %s', lame_workers) - - if self.saver(): - logging.info('ShutdownHook: saving checkpoint to %s', - self._checkpoint_prefix) - self.saver().save( - run_context.session, - self._checkpoint_prefix, - global_step=training_util.get_global_step(), - write_state=True, - ) - else: - logging.info('ShutdownHook: no Saver defined.') - - for fn in self._on_shutdown_hooks: - fn(run_context, self._workers, lame_workers) - - -class RestartComputation(object): - """Restart the entire computation. - - This hook shuts down all workers and returns control to the top-level by - throwing a CoordinatorShutdownException. - """ - - def __init__(self, timeout_ms=10000): - self.timeout_ms = timeout_ms - - def __call__(self, run_context, all_workers, lame_workers): - del run_context, lame_workers - all_workers.shutdown(timeout_ms=self.timeout_ms) - - logging.info('Terminating coordinator.') - raise CoordinatorShutdownException() - - -class ShutdownLameWorkers(object): - """Shutdown lamed workers. - - Processing will continue normally (typically by waiting for the down - workers to be restarted). 
- """ - - def __init__(self, timeout_ms=10000): - self.timeout_in_ms = timeout_ms - - def __call__(self, run_context, all_workers, lame_workers): - lame_workers.shutdown(timeout_ms=self.timeout_in_ms) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.session_support import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index 70baea203cc6174bebc7d90646045efae5f2391d..73db253fd790f26679fb05bd6e7a5da6a99da1a7 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -1,553 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ======================================================================== -"""A utility to trace tensor values on TPU.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import os.path -import re - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_util -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import tf_logging as logging - -_TRACER_LOG_PREFIX = ' [>>>TT>>>]' -_DEVICE_TYPE_TPU = 'tpu' -_DEVICE_TYPE_CPU = 'cpu' -_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP' -_TRACE_MODE_NAN_INF = 'nan-inf' -_TRACE_MODE_PART_TENSOR = 'part-tensor' -_TRACE_MODE_PART_TENSOR_SIZE = 3 -_TRACE_MODE_FULL_TENSOR = 'full-tensor' -_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' -_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace' -_RECORD_FILTERED_OUT = 'not-traced-filtered-out' -_RECORD_SCALAR = 'not-traced-scalar' -_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' -_RECORD_GET_TRACED = 'get-traced' -_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' -_MARKER_SECTION_END = '!!!!!!! 
section-end:' -_SECTION_NAME_CONFIG = 'configuration' -_SECTION_NAME_REASON = 'reason' -_SECTION_NAME_OP_LIST = 'op-list' -_SECTION_NAME_GRAPH = 'graph' -_FIELD_NAME_VERSION = 'version:' -_FIELD_NAME_DEVICE = 'device:' -_FIELD_NAME_TRACE_MODE = 'trace-mode:' -_FIELD_NAME_NUM_REPLICAS = 'num-replicas:' -_FIELD_NAME_NUM_OPS = 'number-of-ops:' -_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' -_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' -_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") -_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') -_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') -_FLAG_NAME_ENABLE = 'enable' -_FLAG_NAME_TRACE_MODE = 'trace_mode' -_FLAG_NAME_INTERESTING_OPS = 'interesting_ops' -_FLAG_NAME_TRACE_FILE = 'trace_file_path' -_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' -_FLAG_NAME_OP_RANGE = 'op_range' -_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') -_OUTPUT_STREAM_ESCAPE = 'file://' -_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' - - -class TensorTracer(object): - """A software construct for tracing tensor values in a TF graph on TPU. - - This utility is disabled by default. It can be enabled by setting - the TENSOR_TRACER_FLAGS env variable as: - export TENSOR_TRACER_FLAGS="--enable=1" - If it is enabled, it will trace the output tensor values of - selected Ops in the graph. It has two outputs: (1) the traces and (2) - a report. The traces are dumped to a specified local file on the TPU - host. The report is printed to the log.info of the TPU job. - By passing options via the env variable, users can change: - (1) the trace mode (e.g., detecting NaN/Inf, printing partial or - full tensor values) - (2) which Ops to be traced (via op.name or op.type) - (3) output trace file path. 
- """ - - @staticmethod - def _match_next_flag(flags, pos): - """Returns the match for the next TensorTracer flag.""" - - match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) - if match: - return match - match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) - if match: - return match - match = _FLAG_NO_QUOTE_PAT.match(flags, pos) - return match - - @staticmethod - def print_flag_values(): - """Prints all TensorTracer flags passed via environment variables.""" - - tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) - if not tensor_tracer_flags: - return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR - result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR, - tensor_tracer_flags) - result += 'Individual flag value:\n' - pos = 0 - while True: - match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) - if not match: - break - flag_name = match.group(1) - flag_value = match.group(2) - result += ' %s: %s\n'%(flag_name, flag_value) - pos = match.end() - result += '\n' - return result - - @staticmethod - def get_flag_value(wanted_flag_name): - """Returns the value of a TensorTracer flags.""" - - tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) - if not tensor_tracer_flags: - return '' - pos = 0 - while True: - match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) - if not match: - return '' - flag_name = match.group(1) - flag_value = match.group(2) - if flag_name == wanted_flag_name: - return flag_value - pos = match.end() - return '' - - @staticmethod - def is_enabled(): - """Returns True if TensorTracer is enabled.""" - - flag_value = TensorTracer.get_flag_value(_FLAG_NAME_ENABLE) - flag_value = flag_value.lower() - enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] - return enabled - - @staticmethod - def use_test_undeclared_outputs_dir(): - """Decides the output directory of the trace file. - - Args: - None. - - Returns: - True if the output trace file should be written to the - test-undeclared-outputs-directory defined via an - env variable. 
- """ - - flag_value = TensorTracer.get_flag_value( - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) - flag_value = flag_value.lower() - enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] - return enabled - - @staticmethod - def check_device_type(device_type): - """Checks if the given device type is valid.""" - - if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]: - raise ValueError('Invalid device_type "%s"'%device_type) - - @staticmethod - def check_trace_mode(trace_mode): - """Checks if the given trace mode is valid.""" - - valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR] - if trace_mode not in valid_trace_modes: - raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' - 'Valid trace modes are: %s'%(trace_mode, - valid_trace_modes)) - - @staticmethod - def should_trace(device_type, op): - """Returns True if the given Op should be traced.""" - - if device_type != _DEVICE_TYPE_TPU: - raise ValueError('Non TPU device type is not supported') - if control_flow_util.IsInCond(op): - return False - if op.type in ['Reshape', 'ArgMin', 'ArgMax']: - return False - # pylint: disable=protected-access - return tpu._TPU_REPLICATE_ATTR in op.node_def.attr - # pylint: enable=protected-access - - @staticmethod - def reason(op_idx, details): - """Returns why the Op at op_idx is traced or not.""" - return '%d %s'%(op_idx, details) - - @staticmethod - def topological_sort(g): - """Performs topological sort on the given graph. - - Args: - g: the graph. - - Returns: - A pair where the first element indicates if the topological - sort succeeded (True if there is no cycle found; False if a - cycle is found) and the second element is either the sorted - list of nodes or the cycle of nodes found. - """ - - def visit(op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops): - """Recursively visits all Ops in a graph. - - Args: - op: the current Op being visited. - cycle: a cycle of Ops found. 
- permanently_marked_ops: the set of Ops that were already visited. - temporarily_marked_ops: the set of Ops that we have visited during - the current descent. - sorted_ops: the list of Ops sorted in topological order. - """ - - if cycle: - return - if op in permanently_marked_ops: - return - if op in temporarily_marked_ops: - cycle = temporarily_marked_ops - return - temporarily_marked_ops.add(op) - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - for consumer_op in out_tensor.consumers(): - visit(consumer_op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - # pylint: disable=protected-access - for ctrl_output_op in op._control_outputs: - # pylint: enable=protected-access - visit(ctrl_output_op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - temporarily_marked_ops.remove(op) - permanently_marked_ops.add(op) - sorted_ops.insert(0, op) - - graph_cycle = set([]) - sorted_ops = [] - permanently_marked_ops = set([]) - temporarily_marked_ops = set([]) - unsorted_ops = g.get_operations() - for op in unsorted_ops: - visit(op, graph_cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - if graph_cycle: - return (False, graph_cycle) - else: - assert len(unsorted_ops) == len(sorted_ops) - return (True, sorted_ops) - - def __init__(self): - """Initializes a TensorTracer. - - Sets the various member fields from the flags (if given) or the defaults. 
- """ - self._version = 'use-outside-compilation' - self._device_type = None - self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) - if not self._trace_mode: - self._trace_mode = _TRACE_MODE_NAN_INF - TensorTracer.check_trace_mode(self._trace_mode) - self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE - self._instrument_records = {} - interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS) - self._selected_ops = interesting_ops.split() - self._set_trace_file_path() - self._set_op_range() - self._num_replicas = None - self._replica_id = None - - def _add_replica_id_to_graph(self, num_replicas, result_tensor): - """Adds nodes for computing the replica ID to the graph.""" - - if not num_replicas: - self._replica_id = 'unknown' - return result_tensor - - self._num_replicas = num_replicas - - with ops.control_dependencies(None): - # Uses None as dependency to run outside of TPU graph rewrites. - self._replica_id = tpu_ops.tpu_replicated_input( - list(range(self._num_replicas)), - name='tt_replica_id') - use_replica_id = array_ops.identity(self._replica_id).op - with ops.control_dependencies([use_replica_id]): - # Adds a control dependency from the result_tensor to - # the replica_id to ensure that replica_id will be added to the graph. 
- return array_ops.identity(result_tensor) - - def _set_trace_file_path(self): - """Sets the path of the output trace file.""" - - self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) - if not self._trace_file_path: - raise ValueError('--%s is not set in the environment variable %s' - %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR)) - elif TensorTracer.use_test_undeclared_outputs_dir(): - if os.path.isabs(self._trace_file_path): - raise ValueError('If use_test_undeclared_outputs_dir is set,' - 'trace_file_path cannot be an absolute path (%s)' - %self._trace_file_path) - outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) - self._trace_file_path = os.path.join(outputs_dir, - self._trace_file_path) - - def _set_op_range(self): - """Sets the index range of the Ops that we will consider tracing.""" - - op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) - if not op_range: - self._op_range = (-1, -1) # this means including all ops. - return - match = _OP_RANGE_PAT.match(op_range) - if not match: - self._op_range = (-1, -1) # this means including all ops. 
- return - self._op_range = (int(match.group(1)), int(match.group(2))) - - def _inside_op_range(self, idx): - """Return True if the given index is inside the selected range.""" - - if idx < self._op_range[0]: - return False - return self._op_range[1] < 0 or idx <= self._op_range[1] - - def _write_report(self, content): - """Writes the given content to the report.""" - - logging.info('%s %s'%(_TRACER_LOG_PREFIX, content)) - - def _is_selected_op(self, op_name): - """Returns True if the Op with op_name is selected to be traced.""" - - if not self._selected_ops: - return True - if op_name in self._selected_ops: - return True - return False - - def _write_config_section(self): - """Writes the config section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG)) - self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) - self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) - self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) - self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) - - def _write_reason_section(self): - """Writes the reason section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON)) - for key in sorted(self._instrument_records): - self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) - - def _write_op_list_section(self, op_list): - """Writes the Op-list section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) - self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) - for i in range(0, len(op_list)): - self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type)) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) - - def 
_write_graph_section(self, succeed, sorted_or_cycle): - """Writes the graph section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH)) - self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED, - succeed)) - l = list(sorted_or_cycle) - for i in range(0, len(l)): - self._write_report('%d "%s"\n'%(i, l[i].name)) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) - - def _make_tensor_trace_fun(self, op_name, output_idx): - """Makes the tensor tracing function called by outside compilation. - - Args: - op_name: the name of the Op that outputs the tensor to be traced. - output_idx: which output of the Op it is (0 means the first output). - - Returns: - A function to be passed as the first argument to outside compilation. - - Raises: - RuntimeError: If the trace mode is invalid. - """ - - def _print_tensor(op_name, output_idx, num_elements, tensor, output_tensor): - """Prints a tensor value to a file. - - Args: - op_name: the name of the Op that outputs the tensor to be printed. - output_idx: which output of the Op it is (0 means the first output). - num_elements: number of elements to print. - tensor: the tensor needs to be returned. - output_tensor: the tensor needs to be printed. - - Returns: - The same tensor passed via the "tensor" argument. - """ - msg = '"%s:%d" '%(op_name, output_idx) - output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path - print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), - ' @', self._replica_id, - '\n', output_tensor, - summarize=num_elements, - output_stream=output_stream) - with ops.control_dependencies([print_op]): - return array_ops.identity(tensor).op - - def _detect_nan_inf(tensor): - """Trace function for detecting any NaN/Inf in the tensor.""" - - if tensor.dtype.is_floating: - # Since host can't handle bf16, always convert tensor to f32. 
- tensor = math_ops.cast(tensor, dtypes.float32) - output_tensor = math_ops.reduce_any( - gen_math_ops.logical_or(gen_math_ops.is_nan(tensor), - gen_math_ops.is_inf(tensor))) - else: - output_tensor = constant_op.constant(0) - return _print_tensor(op_name, output_idx, 1, tensor, output_tensor) - - def _show_global_step(tensor): - """Trace function for printing the global step count.""" - - return _print_tensor(op_name, output_idx, 1, tensor, tensor) - - def _show_part_tensor(tensor): - """Trace function for printing part of the tensor.""" - - return _print_tensor(op_name, output_idx, self._part_tensor_size, - tensor, tensor) - - def _show_full_tensor(tensor): - """Trace function for printing the entire tensor.""" - - return _print_tensor(op_name, output_idx, -1, tensor, tensor) - - if op_name == _GLOBAL_STEP_OP_NAME: - return _show_global_step - if self._trace_mode == _TRACE_MODE_NAN_INF: - return _detect_nan_inf - if self._trace_mode == _TRACE_MODE_PART_TENSOR: - return _show_part_tensor - if self._trace_mode == _TRACE_MODE_FULL_TENSOR: - return _show_full_tensor - - raise RuntimeError('Tensor trace fun for %s is not yet implemented' - %self._trace_mode) - - def trace_tpu(self, graph, result_tensor, num_replicas=None): - """Traces the tensors generated by TPU Ops in a TF graph. - - Args: - graph: the graph of Ops. - result_tensor: a result tensor of evaluating the graph. - num_replicas: number of replicas used on the TPU. - - Returns: - A tuple (result_tensor_copy, tracing_ops), where: - result_tensor_copy: an exact copy of result_tensor - tracing_ops: a list of tracing ops. If this list - is non empty, the caller of this function - should pose control dependencies upon these - Ops so that they will be executed when the - graph is evaluated. 
- """ - - self._device_type = _DEVICE_TYPE_TPU - TensorTracer.check_device_type(self._device_type) - result_tensor_copy = self._add_replica_id_to_graph(num_replicas, - result_tensor) - self._write_config_section() - tracing_ops = [] - operations = graph.get_operations() - self._write_op_list_section(operations) - # Does the topological sort before adding any nodes to the graph. - (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) - for op_id, op in enumerate(operations): - if not self._inside_op_range(op_id): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_OUTSIDE_OP_RANGE) - continue - if not TensorTracer.should_trace(self._device_type, op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_SHOULD_NOT_TRACE) - continue - if not self._is_selected_op(op.name): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_FILTERED_OUT) - continue - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - if not out_tensor.get_shape().is_fully_defined(): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_DYNAMIC_SHAPE) - continue # cannot trace tensors with dynamic shape. - rank = len(out_tensor.shape) - if rank < 1: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_SCALAR) - continue # cannot trace scalar. - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_GET_TRACED) - consumers = out_tensor.consumers() - trace_op = tpu.outside_compilation( - self._make_tensor_trace_fun(op.name, i), out_tensor) - if consumers: - for consumer_op in consumers: - # pylint: disable=protected-access - consumer_op._add_control_input(trace_op) - # pylint: enable=protected-access - else: - # if there is no consumer, we will add the control dependence later - # when we add the control dependency to the output operations. 
- tracing_ops.append(trace_op) - - self._write_reason_section() - self._write_graph_section(succeed, sorted_or_cycle) - - return (result_tensor_copy, tracing_ops) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tensor_tracer import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index 6ae718cc2c9716587849aeee8abcd0a1de82a9ae..5bf805752cf51b0a0f4b7400b18b63aae93cf831 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -1,220 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ====================================== -"""Defines the `Topology` class, that describes a TPU fabric topology.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.proto import topology_pb2 - - -def _tpu_device_name(job, task, device): - """Returns the device name for the TPU `device` on `task` of `job`.""" - if job is None: - return "/task:%d/device:TPU:%d" % (task, device) - else: - return "/job:%s/task:%d/device:TPU:%d" % (job, task, device) - - -def _tpu_host_device_name(job, task): - """Returns the device name for the CPU device on `task` of `job`.""" - if job is None: - return "/task:%d/device:CPU:0" % task - else: - return "/job:%s/task:%d/device:CPU:0" % (job, task) - - -class Topology(object): - """Describes a set of TPU devices. - - Represents both the shape of the physical mesh, and the mapping between - TensorFlow TPU devices to physical mesh coordinates. - """ - - def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None): - """Builds a Topology object. - - If `serialized` is not `None`, the topology is parsed from `serialized` and - the other arguments are ignored. Otherwise, the topology is computed from - `mesh_shape` and `device_coordinates`. - - Args: - serialized: A serialized `TopologyProto`, or `None`. If not `None`, the - serialized proto is parsed to discover the topology. - mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`, - the shape of the TPU topology, in number of cores. Ignored if - `serialized` is not `None`. - device_coordinates: A rank 3 numpy array that describes the mapping from - TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored - if `serialized is not `None`. 
- - Raises: - ValueError: If `serialized` does not describe a well-formed topology. - ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence - of 3 positive integers. - ValueError: If `serialized` is `None` and `device_coordinates` is not a - rank 3 numpy int32 array that describes a valid coordinate mapping. - """ - - self._serialized = serialized - - if serialized: - self._parse_topology(serialized) - else: - self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32) - self._device_coordinates = np.asarray(device_coordinates, np.int32) - if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1): - raise ValueError("`mesh_shape` must be a sequence of 3 positive " - "entries; got {}".format(self._mesh_shape)) - - if (len(self._device_coordinates.shape) != 3 or - self._device_coordinates.shape[2] != len(self._mesh_shape)): - raise ValueError("`device_coordinates` must be a rank 3 int32 array " - "with minor dimension equal to the mesh shape rank") - - self._topology_tasks, self._topology_devices = self._invert_topology() - - def _parse_topology(self, serialized): - """Parses a serialized `TopologyProto` into `self`.""" - proto = topology_pb2.TopologyProto() - proto.ParseFromString(serialized) - - self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32) - if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1): - raise ValueError("`mesh_shape` must be a vector of size 3 with positive " - "entries; got {}".format(self._mesh_shape)) - - if proto.num_tasks < 0: - raise ValueError("`num_tasks` must be >= 0; got {}".format( - proto.num_tasks)) - if proto.num_tpu_devices_per_task < 0: - raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format( - proto.num_tpu_devices_per_task)) - - expected_coordinates_size = ( - proto.num_tasks * proto.num_tpu_devices_per_task * len( - proto.mesh_shape)) - if len(proto.device_coordinates) != expected_coordinates_size: - raise ValueError("`device_coordinates` must have shape num_tasks ({}) * " - 
"num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); " - "got shape {}".format(proto.num_tasks, - proto.num_tpu_devices_per_task, - proto.mesh_shape, - len(proto.device_coordinates))) - - coords = np.array(proto.device_coordinates, dtype=np.int32) - if any(coords < 0): - raise ValueError("`device_coordinates` must be >= 0") - coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task, - len(proto.mesh_shape))) - self._device_coordinates = coords - - def _invert_topology(self): - """Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps.""" - tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32) - devices = np.full(list(self.mesh_shape), -1, dtype=np.int32) - for task in xrange(self.device_coordinates.shape[0]): - for device in xrange(self.device_coordinates.shape[1]): - x, y, z = self.device_coordinates[task, device, :] - tasks[x, y, z] = task - devices[x, y, z] = device - return tasks, devices - - @property - def mesh_shape(self): - """A rank 1 int32 array describing the shape of the TPU topology.""" - return self._mesh_shape - - @property - def mesh_rank(self): - """Returns the number of dimensions in the mesh.""" - return len(self._mesh_shape) - - @property - def device_coordinates(self): - """Describes the mapping from TPU devices to topology coordinates. - - Returns: - A rank 3 int32 array with shape `[tasks, devices, axis]`. - `tasks` is the number of tasks in the TPU cluster, `devices` is the number - of TPU devices per task, and `axis` is the number of axes in the TPU - cluster topology. Each entry gives the `axis`-th coordinate in the - topology of a task/device pair. TPU topologies are 3-dimensional, with - dimensions `(x, y, core number)`. - """ - return self._device_coordinates - - def task_ordinal_at_coordinates(self, device_coordinates): - """Returns the TensorFlow task number attached to `device_coordinates`. 
- - Args: - device_coordinates: An integer sequence describing a device's physical - coordinates in the TPU fabric. - - Returns: - Returns the TensorFlow task number that contains the TPU device with those - physical coordinates. - """ - return self._topology_tasks[tuple(device_coordinates)] - - def tpu_device_ordinal_at_coordinates(self, device_coordinates): - """Returns the TensorFlow device number at `device_coordinates`. - - Args: - device_coordinates: An integer sequence describing a device's physical - coordinates in the TPU fabric. - - Returns: - Returns the TensorFlow device number within the task corresponding to - attached to the device with those physical coordinates. - """ - return self._topology_devices[tuple(device_coordinates)] - - def cpu_device_name_at_coordinates(self, device_coordinates, job=None): - """Returns the CPU device attached to a logical core.""" - return _tpu_host_device_name( - job, self._topology_tasks[tuple(device_coordinates)]) - - def tpu_device_name_at_coordinates(self, device_coordinates, job=None): - """Returns the name of the TPU device assigned to a logical core.""" - return _tpu_device_name(job, - self._topology_tasks[tuple(device_coordinates)], - self._topology_devices[tuple(device_coordinates)]) - - @property - def num_tasks(self): - """Returns the number of TensorFlow tasks in the TPU slice.""" - return self._device_coordinates.shape[0] - - @property - def num_tpus_per_task(self): - """Returns the number of TPU devices per task in the TPU slice.""" - return self._device_coordinates.shape[1] - - def serialized(self): - """Returns the serialized form of the topology.""" - if self._serialized is None: - proto = topology_pb2.TopologyProto() - proto.mesh_shape[:] = list(self._mesh_shape) - proto.num_tasks = self._device_coordinates.shape[0] - proto.num_tpu_devices_per_task = self._device_coordinates.shape[1] - proto.device_coordinates.extend(list(self._device_coordinates.flatten())) - self._serialized = 
proto.SerializeToString() - - return self._serialized +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.topology import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index def57da20d6018dcf27ccb7a9d04592f38ce2f7c..5364b20f231ac7af8adf943c3d5e21921b7a06a9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -1,1189 +1,25 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ====================================== - -"""Library of TPU helper functions.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.compiler import xla -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function - -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.compat import compat as api_compat -from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import compat - - -# Operations that indicate some error in the users graph, e.g. a placeholder -# that's introduced outside of the infeed. -_BLACKLISTED_OPS = set([ - "Placeholder", -]) - -# XLA doesn't currently support reading of intermediate tensors, thus some ops -# are not supported. 
-_UNSUPPORTED_OPS = set([ - "AudioSummary", - "AudioSummaryV2", - "HistogramSummary", - "ImageSummary", - "MergeSummary", - "Print", - "ScalarSummary", - "TensorSummary", - "TensorSummaryV2", - ]) - -_MAX_WARNING_LINES = 5 - -_TPU_REPLICATE_ATTR = "_tpu_replicate" -_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" -_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" - - -def _tpu_system_device_name(job): - """Returns the device name for the TPU_SYSTEM device of `job`.""" - if job is None: - return "/device:TPU_SYSTEM:0" - else: - return "/job:%s/device:TPU_SYSTEM:0" % job - - -def initialize_system(embedding_config=None, job=None): - """Initializes a distributed TPU system for use with TensorFlow. - - Args: - embedding_config: If not None, a `TPUEmbeddingConfiguration` proto - describing the desired configuration of the hardware embedding lookup - tables. If embedding_config is None, no hardware embeddings can be used. - job: The job (the XXX in TensorFlow device specification /job:XXX) that - contains the TPU devices that will be initialized. If job=None it is - assumed there is only one job in the TensorFlow flock, and an error will - be returned if this assumption does not hold. - Returns: - A serialized `TopologyProto` that describes the TPU system. Note: - the topology must be evaluated using `Session.run` before it can be used. - """ - config_string = ("" if embedding_config is None else - embedding_config.SerializeToString()) - with ops.device(_tpu_system_device_name(job)): - return tpu_ops.configure_distributed_tpu(embedding_config=config_string) - - -def shutdown_system(job=None): - """Shuts down a running a distributed TPU system.""" - with ops.device(_tpu_system_device_name(job)): - shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() - return shutdown_distributed_tpu - - -def core(num): - """Returns the device name for a core in a replicated TPU computation. 
- - Args: - num: the virtual core number within each replica to which operators should - be assigned. - Returns: - A device name, suitable for passing to `tf.device()`. - """ - return "device:TPU_REPLICATED_CORE:{}".format(num) - - -class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): - """A `ControlFlowContext` for nodes inside a TPU computation. - - The primary role of `TPUReplicateContext` is to mark operators inside a - tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ - is a unique name. - - We use a `ControlFlowContext` to perform the annotation since it integrates - with Tensorflow constructs like ResourceVariables. For example, if a - `ResourceVariable` is constructed inside a tpu.replicate() block, the - `ResourceVariable` implementation can use - `with ops.control_dependencies(None)` to build the variable's definition - outside the replicated computation. - """ - - def __init__(self, name, num_replicas, pivot): - """Builds a new TPUReplicateContext. - - Args: - name: a unique name for the context, used to populate the `_tpu_replicate` - attribute. - num_replicas: an integer that gives the number of replicas for the - computation. - pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any - inputs will have a control dependency on the pivot node. This ensures - that nodes are correctly included in any enclosing control flow - contexts. 
- """ - super(TPUReplicateContext, self).__init__() - self._num_replicas = num_replicas - self._outer_device_function_stack = None - self._oc_dev_fn_stack = None - self._outside_compilation_cluster = None - self._outside_compilation_counter = 0 - self._in_gradient_colocation = None - self._gradient_colocation_stack = [] - self._host_compute_core = [] - self._name = name - self._name_as_bytes = compat.as_bytes(name) - self._unsupported_ops = [] - self._pivot = pivot - self._replicated_vars = {} - - def get_replicated_var_handle(self, name, vars_): - """Returns a variable handle for replicated TPU variable 'var'. - - This is a method used by an experimental replicated variable implementation - and is not intended as a public API. - - Args: - name: The common name of the variable. - vars_: The replicated TPU variables. - - Returns: - The handle of the TPU replicated input node. - """ - handle = self._replicated_vars.get(name) - if handle is not None: - return handle - - # Builds a TPUReplicatedInput node for the variable, if one does not already - # exist. The TPUReplicatedInput node must belong to the enclosing - # control-flow scope of the TPUReplicateContext. - # TODO(phawkins): consider changing the contract of the TPU encapsulation - # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope - # instead. 
- - # pylint: disable=protected-access - graph = ops.get_default_graph() - saved_context = graph._get_control_flow_context() - graph._set_control_flow_context(self.outer_context) - handle = tpu_ops.tpu_replicated_input( - [v.handle for v in vars_], name=name + "/handle") - graph._set_control_flow_context(saved_context) - # pylint: enable=protected-access - self._replicated_vars[name] = handle - return handle - - def report_unsupported_operations(self): - if self._unsupported_ops: - op_str = "\n".join([" %s (%s)" % (op.type, op.name) - for op in self._unsupported_ops[:_MAX_WARNING_LINES]]) - logging.warning("%d unsupported operations found: \n%s", - len(self._unsupported_ops), op_str) - if len(self._unsupported_ops) > _MAX_WARNING_LINES: - logging.warning("... and %d more" % - (len(self._unsupported_ops) - _MAX_WARNING_LINES)) - - def EnterGradientColocation(self, op, gradient_uid): - if op is not None: - self._gradient_colocation_stack.append(op) - if not self._outside_compilation_cluster: - try: - outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR) - if self._in_gradient_colocation: - raise NotImplementedError( - "Cannot nest gradient colocation operations outside compilation" - ) - if gradient_uid == "__unsupported__": - raise NotImplementedError( - "No gradient_uid calling gradient within outside_compilation") - # When we take the gradient of an op X in an outside_compilation - # cluster C in a forward computation we would like to put the ops - # corresponding to the gradient of X into a new outside_compilation - # cluster C'. However, if we take the gradient of X twice, the second - # one should get yet another new outside_compilation cluster C''. - # - # The mechanism we adopt is to use a 'root_cluster' which is the - # cluster that X was in before we took gradients, and a 'gradient_uid' - # which is different for every invocation of gradients, and put the - # gradient of X in cluster 'root_cluster.gradient_uid'. 
- # - # When taking a gradient of a gradient, some ops will be colocated - # with Op in the forward pass (e.g., cluster root_cluster) and some in - # the backward pass (e.g., cluster root_cluster.initial_gradient_uid). - # We need all of the grad-of-grad ops to be in the same cluster to - # avoid cyclic dependencies between clusters. We adopt a heuristic - # that puts any op clustered with root_cluster. in - # root_cluster.gradient_uid, even if xxx was initial_gradient_uid. - self._in_gradient_colocation = op - parts = outside_attr.split(".") - cluster = parts[0] + "." + gradient_uid - self._EnterOutsideCompilationScope(cluster=cluster) - except ValueError: - # The attr was not present: do nothing. - pass - - def ExitGradientColocation(self, op, gradient_uid): - if op is not None: - if not self._gradient_colocation_stack: - raise errors.InternalError( - op.node_def, op, - "Badly nested gradient colocation: empty stack when popping Op " + - op.name) - last_op = self._gradient_colocation_stack.pop() - if op is last_op: - if op is self._in_gradient_colocation: - self._in_gradient_colocation = None - self._ExitOutsideCompilationScope() - else: - raise errors.InternalError( - op.node_def, op, "Badly nested gradient colocation, expected " + - last_op + ", got " + op.name) - - def _EnterOutsideCompilationScope(self, cluster=None): - - class FakeOp(object): - """A helper class to determine the current device. - - Supports only the type and device set/get methods needed to run the - graph's _apply_device_function method. 
- """ - - def __init__(self): - self._device = "" - - @property - def type(self): - return "FakeOp" - - @property - def device(self): - return self._device - - def _set_device(self, device): - if isinstance(device, pydev.DeviceSpec): - self._device = device.to_string() - else: - self._device = device - - if self._outside_compilation_cluster: - raise NotImplementedError("Cannot nest outside_compilation clusters") - if cluster: - self._outside_compilation_cluster = cluster - else: - self._outside_compilation_cluster = str(self._outside_compilation_counter) - self._outside_compilation_counter += 1 - graph = ops.get_default_graph() - fake_op = FakeOp() - graph._apply_device_functions(fake_op) # pylint: disable=protected-access - device = pydev.DeviceSpec.from_string(fake_op.device) - if (device.device_type == "TPU_REPLICATED_CORE" and - device.device_index is not None): - self._host_compute_core.append(self._outside_compilation_cluster + ":" + - str(device.device_index)) - self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access - graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access - - def _ExitOutsideCompilationScope(self): - if not self._outside_compilation_cluster: - raise NotImplementedError( - "Attempted to exit outside_compilation scope when not in scope") - self._outside_compilation_cluster = None - graph = ops.get_default_graph() - graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access - - def Enter(self): - if not self._outer_device_function_stack: - # Capture the device function stack at the time of first entry - # since that is the stack that will be used outside_compilation. 
- graph = ops.get_default_graph() - # pylint: disable=protected-access - self._outer_device_function_stack = graph._device_function_stack.copy() - # pylint: enable=protected-access - super(TPUReplicateContext, self).Enter() - - def HostComputeCore(self): - return self._host_compute_core - - def AddOp(self, op): - # pylint: disable=protected-access - if op.type in _BLACKLISTED_OPS: - logging.error("Operation of type %s (%s) is not supported on the TPU. " - "Execution will fail if this op is used in the graph. " % - (op.type, op.name)) - - if op.type in _UNSUPPORTED_OPS: - self._unsupported_ops.append(op) - - if any(x.dtype._is_ref_dtype for x in op.inputs): - raise NotImplementedError( - "Non-resource Variables are not supported inside TPU computations " - "(operator name: %s)" % op.name) - if _TPU_REPLICATE_ATTR in op.node_def.attr: - raise ValueError("TPU computations cannot be nested") - op._set_attr(_TPU_REPLICATE_ATTR, - attr_value_pb2.AttrValue(s=self._name_as_bytes)) - if self._outside_compilation_cluster: - op._set_attr( - _OUTSIDE_COMPILATION_ATTR, - attr_value_pb2.AttrValue( - s=compat.as_bytes(self._outside_compilation_cluster))) - if self._num_replicas > 1 or not self._outside_compilation_cluster: - # Prevent feeding or fetching anything that is being compiled, - # and any replicated outside_compilation Op. - op.graph.prevent_feeding(op) - op.graph.prevent_fetching(op) - - # Remove any control edges from outer control flow contexts. These may cause - # mismatched frame errors. - (internal_control_inputs, - external_control_inputs) = self._RemoveExternalControlEdges(op) - - if not op.inputs: - # Add a control edge from the control pivot to this op. 
- if not internal_control_inputs: - # pylint: disable=protected-access - op._add_control_input(self.GetControlPivot()) - # pylint: enable=protected-access - else: - for index in xrange(len(op.inputs)): - x = op.inputs[index] - real_x = self.AddValue(x) - if real_x != x: - op._update_input(index, real_x) # pylint: disable=protected-access - - if external_control_inputs: - # Use an identity to pull control inputs as data inputs. Note that we - # ignore ops which don't have outputs. TODO(phawkins): fix that. - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - # pylint: disable=protected-access - op._add_control_inputs(external_control_inputs) - # pylint: enable=protected-access - - # Mark op's outputs as seen by this context and any outer contexts. - output_names = [x.name for x in op.outputs] - context = self - while context is not None: - # pylint: disable=protected-access - context._values.update(output_names) - context = context._outer_context - # pylint: enable=protected-access - - if self._outer_context: - self._outer_context.AddInnerOp(op) - - def AddValue(self, val): - """Add `val` to the current context and its outer context recursively.""" - if val.name in self._values: - # Use the real value if it comes from outer context. 
- result = self._external_values.get(val.name) - return val if result is None else result - - result = val - self._values.add(val.name) - if self._outer_context: - result = self._outer_context.AddValue(val) - self._values.add(result.name) - - self._external_values[val.name] = result - - return result - - def AddInnerOp(self, op): - self.AddOp(op) - if self._outer_context: - self._outer_context.AddInnerOp(op) - - @property - def grad_state(self): - # Define the gradient loop state associated with the TPUReplicateContext to - # be None as the TPUReplicateContext does not get nested nor does the - # grad_state outside the TPUReplicateContext affect the graph inside so the - # grad_state should be as if this is the top-level gradient state. - return None - - @property - def back_prop(self): - """Forwards to the enclosing while context, if any.""" - if self.GetWhileContext(): - return self.GetWhileContext().back_prop - return False - - def GetControlPivot(self): - return self._pivot - - -def outside_compilation(computation, *args, **kwargs): - """Builds part of a computation outside any current TPU replicate scope. - - Args: - computation: A Python function that builds the computation to - place on the host. - *args: the positional arguments for the computation. - **kwargs: the keyword arguments for the computation. - - Returns: - The Tensors returned by computation. 
- """ - args = [] if args is None else args - graph = ops.get_default_graph() - - # If we are in a TPUReplicateContext, signal that we are now - # outside_compilation - initial_context = graph._get_control_flow_context() # pylint: disable=protected-access - context = initial_context - while context: - if isinstance(context, TPUReplicateContext): - context._EnterOutsideCompilationScope() # pylint: disable=protected-access - context = context.outer_context - - retval = computation(*args, **kwargs) - - # If we are in a TPUReplicateContext, signal that we are no longer - # outside_compilation - final_context = graph._get_control_flow_context() # pylint: disable=protected-access - if initial_context is not final_context: - raise NotImplementedError( - "Control-flow context cannot be different at start and end of an " - "outside_compilation scope") - context = initial_context - while context: - if isinstance(context, TPUReplicateContext): - context._ExitOutsideCompilationScope() # pylint: disable=protected-access - context = context.outer_context - - return retval - - -def replicate(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Builds a graph operator that runs a replicated TPU computation. - - Args: - computation: A Python function that builds the computation to replicate. - inputs: A list of lists of input tensors or `None` (equivalent to - `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to computation. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. 
The - `DeviceAssignment` may be omitted if each replica of the computation uses - only one core, and there is either only one replica, or the number of - replicas is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - Returns: - A list of lists of output tensors, indexed by `[replica_num][output_num]`. - Raises: - ValueError: If all replicas do not have equal numbers of input tensors. - ValueError: If the number of inputs per replica does not match - the number of formal parameters to `computation`. - """ - return split_compile_and_replicate(computation, inputs, infeed_queue, - device_assignment, name)[1] - - -def split_compile_and_replicate(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None, - use_tpu=True): - """Builds graph operators that runs compilation and replicated computation. - - This is a lower level interface than replicate that returns a separate compile - and execute output tensor. In the generated graph the compile op feeds into - the execute op and no additional compilation is incurred when running the - compile op before the execute op. The compile op returns additional - information about the compilation but does not return the compiled program. - - Args: - computation: A Python function that builds the computation to replicate. - inputs: A list of lists of input tensors or `None` (equivalent to - `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to computation. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. 
The - `DeviceAssignment` may be omitted if each replica of the computation uses - only one core, and there is either only one replica, or the number of - replicas is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU - backends. Currently, only supports a default placement (computation is - placed on GPU if one is available, and on CPU if not). - Returns: - A list of lists with the first list corresponding to the compile op and the - second a list of output tensors, indexed by `[replica_num][output_num]`. - Raises: - ValueError: If all replicas do not have equal numbers of input tensors. - ValueError: If the number of inputs per replica does not match - the number of formal parameters to `computation`. - """ - del name - inputs = [[]] if inputs is None else inputs - - metadata_kwargs = {} - if device_assignment is not None: - # Turn the Numpy array into a flattened list so we can pass it as an - # operator attribute. - metadata_kwargs = { - "topology": - device_assignment.topology.serialized(), - "device_assignment": - device_assignment.core_assignment.flatten().tolist() - } - # TODO(phawkins): remove this case after the forward compatibility window - # expires on 2018-10-5. - if api_compat.forward_compatible(2018, 10, 5): - metadata_kwargs["num_cores_per_replica"] = ( - device_assignment.num_cores_per_replica) - else: - metadata_kwargs["computation_shape"] = [ - device_assignment.num_cores_per_replica - ] - - if ((not isinstance(inputs, list)) or - any(not isinstance(inp, (list, tuple)) for inp in inputs)): - raise TypeError("tpu.replicate() inputs must be a list of lists/tuples") - - num_replicas = len(inputs) - - # No replicas? Nothing to do. - if num_replicas == 0: - return [] - - # Converts inputs to Tensors. 
- inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs] - - # Verifies that all replicas have matching numbers and types of inputs - input_types = [x.dtype for x in inputs[0]] - input_arity = len(input_types) - for i in range(num_replicas): - if len(inputs[i]) != input_arity: - raise ValueError("Replicas must have the same number of inputs. " - "Replica 0 had {} inputs, replica {} had {} " - "inputs.".format(input_arity, i, len(inputs[i]))) - - types = [x.dtype for x in inputs[i]] - if types != input_types: - raise ValueError( - "Replicas must have matching input types. Replica 0 had " - "input types {}, replica {} had input types {}".format( - input_types, i, types)) - - arg_error = xla.check_function_argument_count( - computation, input_arity, infeed_queue) - if arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied computation cannot be called with the specified inputs. " - "You specified %d inputs: %s, but the computation needs %s" % ( - input_arity, str([i.name for i in inputs[0]]), arg_error)) - else: - raise TypeError( - "Supplied computation cannot be called with the specified inputs. " - "You specified %d inputs: %s and %d additional inputs from infeed," - " but the computation needs %s" % (input_arity, str( - [i.name - for i in inputs[0]]), infeed_queue.number_of_tuple_elements, - arg_error)) - - graph = ops.get_default_graph() - - # Fan-in: Builds a TPUReplicatedInput node for each input. 
- computation_inputs = [] - for i in range(0, input_arity): - replicas = [inputs[replica][i] for replica in xrange(num_replicas)] - computation_inputs.append( - tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) - - cluster_name = graph.unique_name("cluster") - pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") - context = TPUReplicateContext( - name=cluster_name, num_replicas=num_replicas, pivot=pivot) - try: - context.Enter() - - metadata = tpu_ops.tpu_replicate_metadata( - num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) - - with tpu_function.tpu_shard_context( - num_replicas), ops.control_dependencies([metadata]): - - # Add identity ops so even unused inputs are "consumed" by the - # computation. This is to avoid orphaned TPUReplicatedInput nodes. - # TODO(phawkins): consider instead pruning unused TPUReplicatedInput - # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. - computation_inputs = [ - array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs) - ] - - # If there is an infeed queue, adds the dequeued values to the - # computation's inputs. - if infeed_queue is not None: - infeed_queue.set_number_of_shards(num_replicas) - for t in infeed_queue.generate_dequeue_op(): - computation_inputs.append(t) - - # Only resource variables work inside a TPU computation, so turn on - # resource variables for the computation. - # TODO(phawkins): consider removing this code. It will - # be less confusing to clients if they knowingly choose to use resource - # variables. - # Partitioned variables is not supported (b/112311320). 
- vscope = variable_scope.get_variable_scope() - saved_use_resource = vscope.use_resource - saved_custom_getter = vscope.custom_getter - - def custom_getter(getter, name, *args, **kwargs): - """Variables on TPU have a few restrictions.""" - partitioner = kwargs["partitioner"] - if partitioner is not None: - kwargs["partitioner"] = None - logging.warning( - "Partitioned variables are not supported on TPU. Got " - "`partitioner` that is {} for variable {}. " - "Setting `partitioner` to `None`." - .format(partitioner, name)) - if saved_custom_getter is None: - return getter(name, *args, **kwargs) - else: - return saved_custom_getter(getter, name, *args, **kwargs) - - vscope.set_use_resource(True) - vscope.set_custom_getter(custom_getter) - - outputs = computation(*computation_inputs) - - vscope.set_use_resource(saved_use_resource) - vscope.set_custom_getter(saved_custom_getter) - - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, makes it a tuple. - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - - # Append `no_op` here so that fetching any return value of this function - # will trigger TPUExecute node. - outputs += (control_flow_ops.no_op(),) - try: - with ops.device(core(0)): - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - "TPU function return values must all either be Operations or " - "convertible to Tensors. Got '%s'" % str(e)) - - # Separates the returned Operations and Tensors. 
- output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU functions must return zero-or more Tensor values followed by " - "zero or more Operations.") - output_arity = len(output_tensors) - - # Wraps outputs in Identity ops. Otherwise a replicated input copied - # straight to an output would bypass the replicate(). This would be bad - # because the TPUReplicatedInput/TPUReplicatedOutput operator would not - # be rewritten away, leading to a runtime error. - # TODO(phawkins): extend the rewrite to elide these nodes instead. - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else core(0)): - new_output_tensors.append(array_ops.identity(t)) - output_tensors = new_output_tensors - context.ExitResult(output_tensors) - finally: - context.report_unsupported_operations() - context.Exit() - host_compute_core = context.HostComputeCore() - - if host_compute_core: - attr_value = attr_value_pb2.AttrValue() - attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core]) - metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access - - # Fan-out: Builds a TPUReplicatedOutput node for each output. 
- outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, - name="output{}".format(i)) - for i in xrange(output_arity)] - - with ops.control_dependencies([metadata]): - if use_tpu: - compile_status = tpu_ops.tpu_compilation_result() - op = compile_status.op - attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) - op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access - else: - compile_status = control_flow_ops.no_op(name="compilation_status") - - with ops.control_dependencies(output_operations): - if output_arity == 0: - # Returns a list of NoOps dependent on the replication Op, indexed by - # [replica_num]. - return [ - compile_status, [ - control_flow_ops.no_op(name="shard_%d" % i) - for i in range(num_replicas) - ] - ] - else: - # Wraps the outputs in identity operators so the names of any possible - # `fetch` nodes are preserved by the replication rewrite. - return [ - compile_status, [[ - array_ops.identity( - outputs[out][replica], - name="output_%d_shard_%d" % (out, replica)) - for out in xrange(output_arity) - ] - for replica in xrange(num_replicas)] - ] - - -def shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Shards `computation` for parallel execution. - - `inputs` must be a list of Tensors or None (equivalent to an empty list), each - of which has a corresponding split axis (from `input_shard_axes`). Each input - is split into `num_shards` pieces along the corresponding axis, and - computation is applied to each shard in parallel. - - Tensors are broadcast to all shards if they are lexically captured by - `computation`. e.g., - - x = tf.constant(7) - def computation(): - return x + 3 - ... = shard(computation, ...) - - TODO(phawkins): consider adding support for broadcasting Tensors passed - as inputs. 
- - If `outputs_from_all_shards` is true, the outputs from all shards of - `computation` are concatenated back together along their `output_shards_axes`. - Otherwise, each output is taken from an arbitrary shard. - - Inputs and outputs of the computation must be at least rank-1 Tensors. - - Args: - computation: A Python function that builds a computation to apply to each - shard of the input. - inputs: A list of input tensors or None (equivalent to an empty list). Each - input tensor has a corresponding shard axes, given by `input_shard_axes`, - which must have size divisible by `num_shards`. - num_shards: The number of shards. - input_shard_axes: A list of dimensions along which to shard `inputs`, or - `None`. `None` means "shard all inputs along dimension 0". If not `None`, - there must be one dimension per input. - outputs_from_all_shards: Boolean or list of boolean. For each output, if - `True`, outputs from all shards are concatenated along the corresponding - `output_shard_axes` entry. Otherwise, each output is taken - from an arbitrary shard. If the argument is a boolean, the argument's - value is used for each output. - output_shard_axes: A list of dimensions along which to concatenate the - outputs of `computation`, or `None`. `None` means "concatenate all outputs - along dimension 0". If not `None`, there must be one dimension per output. - Ignored if `outputs_from_all_shards` is False. - infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs - of `computation`. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each shard of the computation uses - only one core, and there is either only one shard, or the number of shards - is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. 
- Returns: - A list of output tensors. - Raises: - ValueError: If num_shards <= 0 - ValueError: If len(input_shard_axes) != len(inputs) - ValueError: If len(output_shard_axes) != len(outputs from `computation`) - """ - - if num_shards <= 0: - raise ValueError("num_shards must be a positive integer.") - - inputs = [] if inputs is None else inputs - if not isinstance(inputs, list): - raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.") - - # Converts inputs to Tensors. - inputs = [ops.convert_to_tensor(x) for x in inputs] - - if input_shard_axes is None: - input_shard_axes = [0] * len(inputs) - if len(inputs) != len(input_shard_axes): - raise ValueError("Length of input_shard_axes must be equal to the number " - "of inputs.") - - if inputs: - # Splits the `inputs` along the corresponding `input_shard_axes`, giving - # lists with layout [input][shard] - split_inputs = [ - array_ops.split(x, num_shards, axis=axis) - for (axis, x) in zip(input_shard_axes, inputs)] - - # Transposes the input lists to have layout [shard][input] - transposed_inputs = [list(i) for i in zip(*split_inputs)] - else: - transposed_inputs = [[]] * num_shards - - outputs = replicate( - computation, - transposed_inputs, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - - # There must be at least one shard since num_shards > 0. - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - if isinstance(outputs[0], ops.Operation): - # pylint: enable=indexing-exception - # There were no outputs from the computation and replicate returned a list - # of NoOps with control dependencies on the computation. Return the first - # one so it can be used as a control dependency or fetch node. - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - return [outputs[0]] - # pylint: enable=indexing-exception - - # TODO(b/36647078) remove disable when pylint bug is fixed. 
- # pylint: disable=indexing-exception - num_outputs = len(outputs[0]) - # pylint: enable=indexing-exception - - if output_shard_axes is None: - output_shard_axes = [0] * num_outputs - if num_outputs != len(output_shard_axes): - raise ValueError("Length of output_shard_axes must be equal to the number " - "of outputs.") - - if isinstance(outputs_from_all_shards, bool): - outputs_from_all_shards = [outputs_from_all_shards] * num_outputs - - if num_outputs != len(outputs_from_all_shards): - raise ValueError("Length of outputs_from_all_shards must be equal to the " - "number of outputs.") - - results = [] - for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, - zip(*outputs)): - if all_shards: - # Concatenate all of the outputs together (use stack for scalars). - shape = x[0].shape - is_scalar = shape is not None and (shape.ndims == 0) - results.append((array_ops.stack(list(x)) if is_scalar - else array_ops.concat(list(x), axis=axis))) - else: - # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. - results.append(x[0]) - - return results - - -def batch_parallel(computation, - inputs=None, - num_shards=1, - infeed_queue=None, - device_assignment=None, - name=None): - """Shards `computation` along the batch dimension for parallel execution. - - Convenience wrapper around shard(). - - `inputs` must be a list of Tensors or None (equivalent to an empty list). - Each input is split into `num_shards` pieces along the 0-th dimension, and - computation is applied to each shard in parallel. - - Tensors are broadcast to all shards if they are lexically captured by - `computation`. e.g., - - x = tf.constant(7) - def computation(): - return x + 3 - ... = shard(computation, ...) - - The outputs from all shards are concatenated back together along their 0-th - dimension. - - Inputs and outputs of the computation must be at least rank-1 Tensors. 
- - Args: - computation: A Python function that builds a computation to apply to each - shard of the input. - inputs: A list of input tensors or None (equivalent to an empty list). The - 0-th dimension of each Tensor must have size divisible by `num_shards`. - num_shards: The number of shards. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each shard of the computation uses - only one core, and there is either only one shard, or the number of shards - is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - Returns: - A list of output tensors. - Raises: - ValueError: If `num_shards <= 0` - """ - return shard( - computation, - inputs, - num_shards=num_shards, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - - -def rewrite(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Rewrites `computation` for execution on a TPU system. - - Args: - computation: A Python function that builds a computation to apply to the - input. If the function takes n inputs, 'inputs' should be a list of n - tensors. - - `computation` may return a list of operations and tensors. Tensors must - come before operations in the returned list. The return value of - `rewrite` is a list of tensors corresponding to the tensors from the - output of `computation`. - - All `Operation`s constructed during `computation` will be executed when - evaluating any of the returned output tensors, not just the ones returned. - inputs: A list of input tensors or `None` (equivalent to an empty list). 
- infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: if not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. May be omitted for a single-core computation, in which - case the core attached to task 0, TPU device 0 is used. - name: (Deprecated) Does nothing. - Returns: - A list of output tensors. - """ - if inputs is not None and not isinstance(inputs, (list, tuple)): - raise TypeError("tpu.rewrite() inputs must be a list or tuple") - - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - return replicate( - computation, - None if inputs is None else [inputs], - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name)[0] - # pylint: enable=indexing-exception - - # Operations that indicate some error in the user's inference graph. -_BLACKLISTED_INFERENCE_OPS = set([ - "ReadVariableOp", - "AssignVariableOp", - "AssignAddVariableOp", - "AssignSubVariableOp", - "VarHandleOp", - "Variable", - "VariableV2", -]) - - -def under_tpu_inference_context(): - """Check if it is currently under `tpu.rewrite_for_inference()`.""" - graph = ops.get_default_graph() - - context = graph._get_control_flow_context() # pylint: disable=protected-access - while context: - if isinstance(context, _TPUInferenceContext): - return True - context = context.outer_context - - return False - - -class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): - """A `ControlFlowContext` for nodes inside a TPU inference computation. - - The primary role of `TPUReplicateContext` is to sanity check operators inside - a tpu.rewrite_for_inference() computation. 
- """ - - def __init__(self, name): - super(_TPUInferenceContext, self).__init__() - self._name = name - - def AddOp(self, op): - self._AddOpInternal(op) - - def _AddOpInternal(self, op): - # pylint: disable=protected-access - if op.type in _BLACKLISTED_INFERENCE_OPS: - raise NotImplementedError( - "Operation of type %s (%s) is not supported on the TPU for inference." - " Execution will fail if this op is used in the graph. Make sure your" - " variables are using variable_scope." % (op.type, op.name)) - if self._outer_context: - self._outer_context.AddInnerOp(op) - - def AddValue(self, val): - result = val - if self._outer_context: - result = self._outer_context.AddValue(val) - return result - - def AddInnerOp(self, op): - self._AddOpInternal(op) - - @property - def grad_state(self): - return None - - -@experimental -def validate_inference_rewrite_for_variables(graph): - """Validates whether rewrite_for_inference() 'worked' for variables. - - The rewrite_for_inference() method is supposed to append GuaranteeConstOps - after ReadVariableOps, but this mechanism works only if you are using - tf.get_variable() to create and access variables in your tpu computation. - This validation method can be called immediately after calling - tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added - to the graph. - - Typical usages: - tpu.validate_inference_rewrite_for_variables(tf.get_default_graph()) - - tpu.validate_inference_rewrite_for_variables(sess.graph) - - Args: - graph: The graph which needs to be validated. - Raises: - RuntimeError: if validation failed. - """ - if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): - raise RuntimeError( - "No GuaranteeConst ops found in the graph after running " - "tpu.rewrite_for_inference(...). 
Please check that you are using " - "tf.get_variable() to create and access variables in your tpu " - "computation.") - - -@experimental -def rewrite_for_inference(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Rewrites `computation` for inference on a TPU system. - - Other than 'rewriting' the computation to run on a TPU, if using variables - in your computation, it moves the ReadVariableOps outside the TPU - computation, and adds GuaranteeConst ops just after the ReadVariableOps. - This mechanism works only if you are using tf.get_variable() to create and - access variables in your tpu computation. You can validate whether this - worked, by calling validate_inference_rewrite_for_variables() method - immediately after this method to check whether GuaranteeConstOps where - added to the graph. - - Args: - computation: A Python function that builds a computation to apply to the - input. If the function takes n inputs, 'inputs' should be a list of n - tensors. If the function returns m outputs, rewrite will return a list of - m tensors. - inputs: A list of input tensors or `None` (equivalent to an empty list). - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: if not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. May be omitted for a single-core computation, in which - case the core attached to task 0, TPU device 0 is used. - name: The name of the operator. - Returns: - A list of output tensors. 
- """ - - def guarantee_const_getter(getter, name, *args, **kwargs): - with ops.control_dependencies(None): - return array_ops.guarantee_const( - getter(name, *args, **kwargs), name=name + "/GuaranteeConst") - - def wrapped_computation(*args, **kwargs): - """Execute computation under `_TPUInferenceContext`.""" - context = _TPUInferenceContext( - name=ops.get_default_graph().unique_name("rewrite_for_inference")) - try: - context.Enter() - - vscope = variable_scope.get_variable_scope() - prev_custom_getter = vscope.custom_getter - prev_caching_device = vscope.caching_device - vscope.set_custom_getter(guarantee_const_getter) - vscope.set_caching_device(lambda op: op.device) - - result = computation(*args, **kwargs) - - vscope.set_custom_getter(prev_custom_getter) - vscope.set_caching_device(prev_caching_device) - finally: - context.Exit() - return result - - # pylint: disable=undefined-variable - return rewrite( - wrapped_computation, - inputs=inputs, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - # pylint: enable=undefined-variable +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu import * +# used by tests +from tensorflow.python.tpu.tpu import _TPU_REPLICATE_ATTR +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 9f8d14706845baa1ed45c84b2c15d372915a0eb4..c36aaa38c0e4823bfc438773e4aa5b5109794da4 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -1,275 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== - -"""A RunConfig subclass with TPU support.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import json -import os - -from tensorflow.contrib.tpu.python.tpu import util as util_lib -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.platform import tf_logging as logging - -# pylint: disable=protected-access -_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV -_SERVICE_KEY = run_config_lib._SERVICE_KEY -_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name' -# pylint: enable=protected-access - - -class InputPipelineConfig(object): - r"""Please see the definition of these values in TPUConfig.""" - PER_SHARD_V1 = 1 - PER_HOST_V1 = 2 - PER_HOST_V2 = 3 - BROADCAST = 4 - - -class TPUConfig( - collections.namedtuple('TPUConfig', [ - 'iterations_per_loop', - 'num_shards', - 'num_cores_per_replica', - 'per_host_input_for_training', - 'tpu_job_name', - 'initial_infeed_sleep_secs', - 'input_partition_dims', - ])): - r"""TPU related configuration required by `TPUEstimator`. - - Args: - iterations_per_loop: This is the number of train steps running in TPU - system before returning to CPU host for each `Session.run`. 
This means - global step is increased `iterations_per_loop` times in one `Session.run`. - It is recommended to be set as number of global steps for next checkpoint. - num_shards: (Deprecated, ignored by TPUEstimator). - The number of model replicas in the system. For non-model-parallelism - case, this number equals the total number of TPU cores. For - model-parallelism, the total number of TPU cores equals - num_cores_per_replica * num_shards. - num_cores_per_replica: Defaults to `None`, which disables model parallelism. - An integer which describes the number of TPU cores per model replica. This - is required by model-parallelism which enables partitioning - the model to multiple cores. Currently num_cores_per_replica must be - 1, 2, 4, or 8. - per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`, - `input_fn` is invoked once on each host. With the per-core input pipeline - configuration, it is invoked once for each core. - With a global batch size `train_batch_size` in `TPUEstimator` constructor, - the batch size for each shard is `train_batch_size` // #hosts in the - `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is - `train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only - invoked once on host 0 and the tensors are broadcasted to all other - replicas. The batch size equals to train_batch_size`. With the per-core - input pipeline configuration, the shard batch size is also - `train_batch_size` // #cores. - Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN. - tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred - within TPUEstimator, however when using ClusterSpec propagation in more - esoteric cluster configurations, you may need to specify the job name as a - string. - initial_infeed_sleep_secs: The number of seconds the infeed thread should - wait before enqueueing the first batch. This helps avoid timeouts for - models that require a long compilation time. 
- input_partition_dims: A nested list to describe the partition dims - for all the tensors from input_fn(). The structure of - input_partition_dims must match the structure of `features` and - `labels` from input_fn(). The total number of partitions must match - `num_cores_per_replica`. For example, if input_fn() returns two tensors: - images with shape [N, H, W, C] and labels [N]. - input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4 - pieces and feed into 4 TPU cores. labels tensor are directly broadcasted - to all the TPU cores since the partition dims is `None`. - Current limitations: This feature is only supported with the PER_HOST_V2 - input mode. - - Raises: - ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16. - """ - - def __new__(cls, - iterations_per_loop=2, - num_shards=None, - num_cores_per_replica=None, - per_host_input_for_training=True, - tpu_job_name=None, - initial_infeed_sleep_secs=None, - input_partition_dims=None): - - # Check iterations_per_loop. - util_lib.check_positive_integer(iterations_per_loop, - 'TPUConfig iterations_per_loop') - - # Check num_shards. 
- if num_shards is not None: - util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') - - if input_partition_dims is not None: - if len(input_partition_dims) != 1 and len(input_partition_dims) != 2: - raise ValueError( - 'input_partition_dims must be a list/tuple with one or two' - ' elements.') - - if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2: - raise ValueError( - 'input_partition_dims is only supported in PER_HOST_V2 mode.') - - if num_cores_per_replica is None: - raise ValueError( - 'input_partition_dims requires setting num_cores_per_replica.') - - # Check num_cores_per_replica - if num_cores_per_replica is not None: - if num_cores_per_replica not in [1, 2, 4, 8, 16]: - raise ValueError( - 'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format( - str(num_cores_per_replica))) - - # per_host_input_for_training may be True, False, or integer in [1..3]. - # Map legacy values (True, False) to numeric values. - if per_host_input_for_training is False: - per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1 - elif per_host_input_for_training is True: - per_host_input_for_training = InputPipelineConfig.PER_HOST_V1 - - # Check initial_infeed_sleep_secs. 
- if initial_infeed_sleep_secs: - util_lib.check_positive_integer(initial_infeed_sleep_secs, - 'TPUConfig initial_infeed_sleep_secs') - - tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config() - - return super(TPUConfig, cls).__new__( - cls, - iterations_per_loop=iterations_per_loop, - num_shards=num_shards, - num_cores_per_replica=num_cores_per_replica, - per_host_input_for_training=per_host_input_for_training, - tpu_job_name=tpu_job_name, - initial_infeed_sleep_secs=initial_infeed_sleep_secs, - input_partition_dims=input_partition_dims) - - -class RunConfig(run_config_lib.RunConfig): - """RunConfig with TPU support.""" - - def __init__(self, - tpu_config=None, - evaluation_master=None, - master=None, - cluster=None, - **kwargs): - """Constructs a RunConfig. - - Args: - tpu_config: the TPUConfig that specifies TPU-specific configuration. - evaluation_master: a string. The address of the master to use for eval. - Defaults to master if not set. - master: a string. The address of the master to use for training. - cluster: a ClusterResolver - **kwargs: keyword config parameters. - - Raises: - ValueError: if cluster is not None and the provided session_config has a - cluster_def already. - """ - super(RunConfig, self).__init__(**kwargs) - self._tpu_config = tpu_config or TPUConfig() - self._cluster = cluster - - # If user sets master and/or evaluation_master explicitly, including empty - # string '', take it. Otherwise, take the values set by parent class. - if master is not None: - if cluster is not None: - raise ValueError('Both master and cluster are set.') - self._master = master - else: - if cluster: - self._master = cluster.master() - - if evaluation_master is not None: - self._evaluation_master = evaluation_master - elif (not self._evaluation_master and - self.task_type != run_config_lib.TaskType.EVALUATOR): - # If the task type is EVALUATOR, it means some cluster manager sets the - # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG. 
- # - # Otherwise, it means user executes the code without external cluster - # manager. For that, we optimize the user experience by setting - # evaluation_master to master, unless user overwrites it. - self._evaluation_master = self._master - - # Set the ClusterSpec to use - if cluster: - self._cluster_spec = cluster.cluster_spec() - - # Merge the cluster_def into the ConfigProto. - if self._session_config is None: # pylint: disable=access-member-before-definition - self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) - if self._session_config.HasField('cluster_def'): - raise ValueError( - 'You cannot provide a ClusterResolver and ' - 'session_config.cluster_def.') - if self._cluster_spec: - self._session_config.cluster_def.CopyFrom( - self._cluster_spec.as_cluster_def()) - - def _maybe_overwrite_session_config_for_distributed_training(self): - # Overrides the parent class session_config overwrite for between-graph. TPU - # runs with in-graph, which should not have device filter. Doing nothing - # ("pass") basically disables it. - pass - - @property - def evaluation_master(self): - return self._evaluation_master - - @property - def master(self): - return self._master - - @property - def tpu_config(self): - return self._tpu_config - - @property - def cluster(self): - return self._cluster - - def replace(self, **kwargs): - if 'tpu_config' not in kwargs: - return super(RunConfig, self).replace(**kwargs) - - tpu_config = kwargs.pop('tpu_config') - new_instance = super(RunConfig, self).replace(**kwargs) - new_instance._tpu_config = tpu_config # pylint: disable=protected-access - return new_instance - - -def _get_tpu_job_name_from_tf_config(): - """Extracts the TPU job name from TF_CONFIG env variable.""" - # TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster - # spec propagation. 
- tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) - tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME) - if tpu_job_name: - logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name) - return tpu_job_name +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_config import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 672462447944b777375331d49727c4d5366cf295..b77b010cba6bf32c3b6d170bc522eebfb6a04f77 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -1,725 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# =================================================================== -"""TPU system metadata and associated tooling.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from contextlib import contextmanager -import copy - -from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment -from tensorflow.contrib.tpu.python.tpu import tpu_config -from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.platform import tf_logging as logging - - -_DEFAULT_JOB_NAME = 'tpu_worker' -_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' -_LOCAL_MASTERS = ('', 'local') -_NUM_CORES_TO_COMPUTATION_SHAPE = { - 1: [1, 1, 1], - 2: [1, 1, 2], - 4: [1, 2, 2], - 8: [2, 2, 2], - 16: [4, 2, 2], -} - - -class TPUContext(object): - """A context that holds the current configuration of the TPU computation.""" - - def __init__(self, - internal_ctx, - input_device=None, - invocation_index=None, - call_from_input_fn=True): - self._internal_ctx = internal_ctx - self._input_device = input_device - self._invocation_index = invocation_index - self._call_from_input_fn = call_from_input_fn - - def current_input_fn_deployment(self): - """The configuration of the current input_fn invocation. - - The configuration depends on `TPUConfig.per_host_input_for_training`. See - `TPUConfig` for details. - - Only set in params dict of input_fn - - Returns: - A tuple of - 1. Device spec string: String, is the current CPU host where the - input_fn is invoked. - 2. Current invocation index: Int, 0-based index of the input_fn - invocation. See next item for details. - 3. Total invocation count: Int, the total number of times to invoke the - input_fn on all CPU hosts. 
Each invocation will be passed with a new - `TPUContext` instance with current invocation index set properly. - 4. Total number of replicas consumed by current_invocation: Int, the - number of replicas fed by the data returned by current input_fn. For - example, for per_core input pipeline deployment - and non-model-parallelism, total invocation count is equal to - the number of cores in the system and num replicas consumed by - current invocation is 1. For per-host v2 input pipeline deployment, - total invocation count is equal to the number of hosts in the system - and num replicas consumed by current invocation is equal to number of - cores per host. - - Raises: - RuntimeError: If this method must not be called from input_fn. - """ - if not self._call_from_input_fn: - raise RuntimeError('This TPUContext instance must not be called from' - ' model_fn.') - - if self._internal_ctx.is_input_sharded_per_core(): - total_invocation_count = (self._internal_ctx.num_hosts - * self._internal_ctx.num_of_replicas_per_host) - replicas_consumed = 1 - elif self._internal_ctx.is_input_broadcast_with_iterators(): - total_invocation_count = 1 - replicas_consumed = self._internal_ctx.num_replicas - else: - total_invocation_count = self._internal_ctx.num_hosts - replicas_consumed = self._internal_ctx.num_of_replicas_per_host - return (self._input_device, self._invocation_index, - total_invocation_count, replicas_consumed) - - @property - def num_replicas(self): - """The total number of replicas. - - For non-model-parallelism, num_replicas should be the total num of TPU - cores in the system. - - Returns: - The number of replicas. 
- """ - return self._internal_ctx.num_replicas - - @property - def num_hosts(self): - """The number of hosts for the TPU system.""" - return self._internal_ctx.num_hosts - - @property - def current_host(self): - """The current host index for the TPU system.""" - return self._invocation_index - - @property - def num_of_replicas_per_host(self): - """The number of replicas for each host.""" - if self._internal_ctx.model_parallelism_enabled: - raise ValueError( - 'num_of_replicas_per_host is not supported for model_parallelism') - return self._internal_ctx.num_of_replicas_per_host - - @property - def device_assignment(self): - """Returns device_assignment object.""" - if self._call_from_input_fn: - raise RuntimeError('This TPUContext instance must not be called from' - ' input_fn.') - return self._internal_ctx.device_assignment - - def device_for_replica(self, replica_id): - """Returns the tuple of (CPU device and device ordinal) for replica. - - This should be used for full replicate for non-model-parallelism. - - Args: - replica_id: Int, the replica index. - - Returns: - A tuple of device spec for CPU device and int device ordinal. - """ - # Note that: For the non-model parallelism, the mapping could be - # a random permutation. The order should not matter in most cases - # as far as model is replicated to all cores in the system. - return self._internal_ctx.device_for_replica(replica_id) - - @property - def tpu_host_placement_function(self): - """Returns the TPU host place function. - - The place function takes host_id as the input and returns the TF device - for the correspoding host. - """ - - def _placement_function(host_id): - """Return the host device given host_id.""" - return self._internal_ctx.tpu_host_placement_function(host_id=host_id) - - return _placement_function - - -class _InternalTPUContext(object): - """A context holds immutable states of TPU computation. 
- - This immutable object holds TPUEstimator config, train/eval batch size, and - `TPUEstimator.use_tpu`, which is expected to be passed around. It also - provides utility functions, based on the current state, to determine other - information commonly required by TPU computation, such as TPU device names, - TPU hosts, shard batch size, etc. - - if eval_on_tpu is False, then execution of eval on TPU is disabled. - if eval_on_tpu is True, but use_tpu is False, a warning is issued, - and TPU execution is disabled for all modes. - - N.B. As `mode` is not immutable state in Estimator, but essential to - distinguish between TPU training and evaluation, a common usage for - _InternalTPUContext with `mode` is as follows: - ``` - with _ctx.with_mode(mode) as ctx: - if ctx.is_running_on_cpu(): - ... - ``` - """ - - def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu=True): - self._config = config - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size - self._predict_batch_size = predict_batch_size - self._use_tpu = use_tpu - logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu) - if not use_tpu and eval_on_tpu: - logging.warning('eval_on_tpu ignored because use_tpu is False.') - - self._eval_on_tpu = eval_on_tpu - self._model_parallelism_enabled = ( - use_tpu and config.tpu_config.num_cores_per_replica) - self._mode = None - num_cores_per_replica = config.tpu_config.num_cores_per_replica - if num_cores_per_replica: - self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[ - num_cores_per_replica] - else: - self._computation_shape = None - self._lazy_tpu_system_metadata_dict = {} # key by master address - self._lazy_device_assignment_dict = {} # key by master address - self._lazy_validation_dict = {} # key by ModeKeys - - def _assert_mode(self): - if self._mode is None: - raise RuntimeError( - '`mode` needs to be set via contextmanager `with_mode`.') - return self._mode - - 
@contextmanager - def with_mode(self, mode): - # NOTE(xiejw): Shallow copy is enough. It will share he lazy dictionaries, - # such as _lazy_tpu_system_metadata_dict between new copy and the original - # one. Note that all lazy states stored in properties _lazy_foo are sort of - # immutable as they should be same for the process lifetime. - new_ctx = copy.copy(self) - new_ctx._mode = mode # pylint: disable=protected-access - yield new_ctx - - @property - def mode(self): - return self._assert_mode() - - def _get_master_address(self): - mode = self._assert_mode() - config = self._config - master = ( - config.master - if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master) - return master - - def _get_tpu_system_metadata(self): - """Gets the (maybe cached) TPU system metadata.""" - master = self._get_master_address() - tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) - if tpu_system_metadata is not None: - return tpu_system_metadata - - cluster_def = None - if (self._config.session_config and - self._config.session_config.cluster_def.job): - cluster_def = self._config.session_config.cluster_def - - # pylint: disable=protected-access - tpu_system_metadata = ( - tpu_system_metadata_lib._query_tpu_system_metadata( - master, - cluster_def=cluster_def, - query_topology=self.model_parallelism_enabled)) - - self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata - return tpu_system_metadata - - def _get_device_assignment(self): - """Gets the (maybe cached) TPU device assignment.""" - master = self._get_master_address() - device_assignment = self._lazy_device_assignment_dict.get(master) - if device_assignment is not None: - return device_assignment - - tpu_system_metadata = self._get_tpu_system_metadata() - - device_assignment = tpu_device_assignment.device_assignment( - tpu_system_metadata.topology, - computation_shape=self._computation_shape, - num_replicas=self.num_replicas) - - logging.info('num_cores_per_replica: %s', - 
str(self._config.tpu_config.num_cores_per_replica)) - logging.info('computation_shape: %s', str(self._computation_shape)) - logging.info('num_replicas: %d', self.num_replicas) - logging.info('device_assignment.topology.device_coordinates: %s', - str(device_assignment.topology.device_coordinates)) - logging.info('device_assignment.core_assignment: %s', - str(device_assignment.core_assignment)) - - self._lazy_device_assignment_dict[master] = device_assignment - return device_assignment - - @property - def model_parallelism_enabled(self): - return self._model_parallelism_enabled - - @property - def input_partition_dims(self): - return self._config.tpu_config.input_partition_dims - - @property - def device_assignment(self): - return (self._get_device_assignment() - if self._model_parallelism_enabled else None) - - @property - def num_of_cores_per_host(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_of_cores_per_host - - @property - def num_cores(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_cores - - @property - def num_of_replicas_per_host(self): - """Return the number of replicas per host.""" - if self.model_parallelism_enabled: - return self.num_replicas // self.num_hosts - else: - return self.num_of_cores_per_host - - @property - def num_replicas(self): - num_cores_in_system = self.num_cores - - if self.model_parallelism_enabled: - num_cores_per_replica = self._config.tpu_config.num_cores_per_replica - if num_cores_per_replica > num_cores_in_system: - raise ValueError( - 'The num of cores required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica, is larger than the total num of ' - 'TPU cores in the system. 
num_cores_per_replica: {}, num cores ' - 'in the system: {}'.format(num_cores_per_replica, - num_cores_in_system)) - - if num_cores_in_system % num_cores_per_replica != 0: - raise RuntimeError( - 'The num of cores in the system ({}) is not divisible by the num ' - 'of cores ({}) required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica. This should never happen!'.format( - num_cores_in_system, num_cores_per_replica)) - - return num_cores_in_system // num_cores_per_replica - else: - return num_cores_in_system - - @property - def num_hosts(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_hosts - - @property - def config(self): - return self._config - - def is_input_sharded_per_core(self): - """Return true if input_fn is invoked per-core (other than per-host).""" - mode = self._assert_mode() - return (mode == model_fn_lib.ModeKeys.TRAIN and - (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_SHARD_V1)) - - def is_input_per_host_with_iterators(self): - """Return true if input_fn should be run in the per-host v2 config.""" - return (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_HOST_V2) - - def is_input_broadcast_with_iterators(self): - """Return true if input_fn should be run in the full_replicae config.""" - return (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.BROADCAST) - - def is_running_on_cpu(self, is_export_mode=False): - """Determines whether the input_fn and model_fn should be invoked on CPU. - - This API also validates user provided configuration, such as batch size, - according the lazy initialized TPU system metadata. - - Args: - is_export_mode: Indicates whether the current mode is for exporting the - model, when mode == PREDICT. 
Only with this bool, we could - tell whether user is calling the Estimator.predict or - Estimator.export_savedmodel, which are running on TPU and CPU - respectively. Parent class Estimator does not distinguish these two. - - Returns: - bool, whether current input_fn or model_fn should be running on CPU. - - Raises: - ValueError: any configuration is invalid. - """ - - is_running_on_cpu = self._is_running_on_cpu(is_export_mode) - if not is_running_on_cpu: - self._validate_tpu_configuration() - return is_running_on_cpu - - def _is_running_on_cpu(self, is_export_mode): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" - mode = self._assert_mode() - - if not self._use_tpu: - return True - - if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu: - logging.info('_is_running_on_cpu: eval_on_tpu disabled') - return True - - if is_export_mode: - return True - - return False - - @property - def global_batch_size(self): - mode = self._assert_mode() - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size - elif mode == model_fn_lib.ModeKeys.EVAL: - return self._eval_batch_size - elif mode == model_fn_lib.ModeKeys.PREDICT: - return self._predict_batch_size - else: - return None - - @property - def batch_size_for_input_fn(self): - """Returns the shard batch size for `input_fn`.""" - global_batch_size = self.global_batch_size - - if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): - return global_batch_size - - # On TPU - if self.is_input_sharded_per_core() or ( - self.is_input_per_host_with_iterators()): - return global_batch_size // self.num_replicas - else: - return global_batch_size // self.num_hosts - - @property - def batch_size_for_model_fn(self): - """Returns the shard batch size for `model_fn`.""" - global_batch_size = self.global_batch_size - - if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): - return global_batch_size - - # On TPU. always sharded per shard. 
- return global_batch_size // self.num_replicas - - @property - def master_job(self): - """Returns the job name to use to place TPU computations on. - - Returns: - A string containing the job name, or None if no job should be specified. - - Raises: - ValueError: If the user needs to specify a tpu_job_name, because we are - unable to infer the job name automatically, or if the user-specified job - names are inappropriate. - """ - run_config = self._config - # If the user specifies the tpu_job_name, use that. - if run_config.tpu_config.tpu_job_name: - return run_config.tpu_config.tpu_job_name - - # The tpu job is determined by the run_config. Right now, this method is - # required as tpu_config is not part of the RunConfig. - mode = self._assert_mode() - master = ( - run_config.evaluation_master - if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) - if master in _LOCAL_MASTERS: - return None - - if (not run_config.session_config or - not run_config.session_config.cluster_def.job): - return _DEFAULT_JOB_NAME - cluster_def = run_config.session_config.cluster_def - job_names = set([job.name for job in cluster_def.job]) - if _DEFAULT_JOB_NAME in job_names: - # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. - raise ValueError('Currently, tpu_worker is not an allowed job name.') - if len(job_names) == 1: - return cluster_def.job[0].name - if len(job_names) == 2: - if _DEFAULT_COORDINATOR_JOB_NAME in job_names: - job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) - return job_names.pop() - # TODO(b/67716447): Include more sophisticated heuristics. - raise ValueError( - 'Could not infer TPU job name. 
Please specify a tpu_job_name as part ' - 'of your TPUConfig.') - - @property - def tpu_host_placement_function(self): - """Returns the TPU host place function.""" - - master = self.master_job - - def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name - """Return the host device given replica_id or host_id.""" - assert _sentinal is None - if replica_id is not None and host_id is not None: - raise RuntimeError( - 'replica_id and host_id can have only one non-None value.') - - if master is None: - return '/replica:0/task:0/device:CPU:0' - else: - if replica_id is not None: - if self.model_parallelism_enabled: - return self.device_assignment.host_device( - replica=replica_id, job=master) - else: - host_id = replica_id / self.num_of_cores_per_host - - return '/job:%s/task:%d/device:CPU:0' % (master, host_id) - - return _placement_function - - @property - def tpu_device_placement_function(self): - """Returns a TPU device placement Fn.""" - master = self.master_job - job_device = '' if master is None else ('/job:%s' % master) - - def _placement_function(i): - if self.model_parallelism_enabled: - return self.device_assignment.tpu_device(replica=i, job=master) - else: - num_of_cores_per_host = self.num_of_cores_per_host - host_id = i / num_of_cores_per_host - ordinal_id = i % num_of_cores_per_host - return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id) - - return _placement_function - - def tpu_ordinal_function(self, host_id): - """Returns the TPU ordinal fn.""" - - def _tpu_ordinal_function(shard_index_in_host): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - shard_index_in_host: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - if self.model_parallelism_enabled: - # We put both enqueue/dequeue ops at tpu.core(0) in each replica. 
- replica = self.device_assignment.lookup_replicas(host_id, - 0)[shard_index_in_host] - return self.device_assignment.tpu_ordinal(replica=replica) - else: - return shard_index_in_host % self.num_of_cores_per_host - - return _tpu_ordinal_function - - def _validate_tpu_configuration(self): - """Validates the configuration based on the TPU system metadata.""" - mode = self._assert_mode() - if self._lazy_validation_dict.get(mode): - return - - # All following information is obtained from TPU system metadata. - num_cores = self.num_cores - num_replicas = self.num_replicas - num_hosts = self.num_hosts - - if not num_cores: - tpu_system_metadata = self._get_tpu_system_metadata() - raise RuntimeError( - 'Cannot find any TPU cores in the system. Please double check ' - 'Tensorflow master address and TPU worker(s). Available devices ' - 'are {}.'.format(tpu_system_metadata.devices)) - - if self._config.tpu_config.num_shards: - user_provided_num_replicas = self._config.tpu_config.num_shards - if user_provided_num_replicas != num_replicas: - message = ( - 'TPUConfig.num_shards is not set correctly. According to TPU ' - 'system metadata for Tensorflow master ({}): num_replicas should ' - 'be ({}), got ({}). For non-model-parallelism, num_replicas should ' - 'be the total num of TPU cores in the system. For ' - 'model-parallelism, the total number of TPU cores should be ' - 'num_cores_per_replica * num_replicas. 
Please set it ' - 'accordingly or leave it as `None`'.format( - self._get_master_address(), num_replicas, - user_provided_num_replicas)) - - raise ValueError(message) - - if self._config.tpu_config.num_cores_per_replica: - num_cores_per_replica = self._config.tpu_config.num_cores_per_replica - num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host - if num_cores_per_replica > num_cores_per_host: - raise ValueError( - 'The num of cores required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica, is larger than the ' - 'num_cores_per_host. num_cores_per_replica: {}, ' - 'num_cores_per_host: {}'.format(num_cores_per_replica, - num_cores_per_host)) - - if mode == model_fn_lib.ModeKeys.TRAIN: - if (self._train_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'train batch size {} must be divisible by number of replicas {}' - .format(self._train_batch_size, num_replicas)) - - elif mode == model_fn_lib.ModeKeys.EVAL: - if self._eval_batch_size is None: - raise ValueError( - 'eval_batch_size in TPUEstimator constructor cannot be `None`' - 'if .evaluate is running on TPU.') - if (self._eval_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'eval batch size {} must be divisible by number of replicas {}' - .format(self._eval_batch_size, num_replicas)) - if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): - raise ValueError( - 'TPUEstimator.evaluate should be running on single TPU' - ' instead of a Pod.') - else: - assert mode == model_fn_lib.ModeKeys.PREDICT - if self._predict_batch_size is None: - raise ValueError( - 'predict_batch_size in TPUEstimator constructor should not be ' - '`None` if .predict is running on TPU.') - if (self._predict_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'predict batch size {} must be divisible by number of 
replicas {}' - .format(self._predict_batch_size, num_replicas)) - if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): - raise ValueError( - 'TPUEstimator.predict should be running on single TPU worker. ' - 'got {}.'.format(num_hosts)) - - # Record the state "validated" into lazy dictionary. - self._lazy_validation_dict[mode] = True - - def device_for_replica(self, replica_id): - """Returns the tuple of (CPU device and device ordinal) for replica. - - This should be used for full replicate for non-model-parallelism. - - Args: - replica_id: Int, the replica index. - - Returns: - A tuple of device spec for CPU device and int device ordinal. - """ - master = self.master_job - - if self.model_parallelism_enabled: - return (self.device_assignment.host_device( - replica=replica_id, job=master), - self.device_assignment.tpu_ordinal(replica=replica_id)) - - job_device = '' if master is None else ('/job:%s' % master) - - num_of_replicas_per_host = self.num_of_replicas_per_host - host_id = replica_id / num_of_replicas_per_host - ordinal_id = replica_id % num_of_replicas_per_host - - host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) - return (host_device, ordinal_id) - - -class _OneCoreTPUContext(_InternalTPUContext): - """Special _InternalTPUContext for one core usage.""" - - def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu): - - super(_OneCoreTPUContext, self).__init__( - config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu) - - def _get_tpu_system_metadata(self): - """Gets the (maybe cached) TPU system metadata.""" - master = self._get_master_address() - tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) - if tpu_system_metadata is not None: - return tpu_system_metadata - - tpu_system_metadata = ( - tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access - num_cores=1, - num_hosts=1, - num_of_cores_per_host=1, - topology=None, - devices=[])) 
- - self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata - return tpu_system_metadata - - -def _get_tpu_context(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu): - """Returns an instance of `_InternalTPUContext`.""" - - if (config.tpu_config.num_shards == 1 and - config.tpu_config.num_cores_per_replica is None): - logging.warning( - 'Setting TPUConfig.num_shards==1 is an unsupported behavior. ' - 'Please fix as soon as possible (leaving num_shards as None.)') - return _OneCoreTPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu) - - return _InternalTPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_context import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index ccba8a46c7cad0337119672e02314684f4451479..cb38a8f1a6bee3c2adfbefc203c1d143303c3368 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -1,10 +1,10 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,1099 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""TPU embedding APIs.""" +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import copy -import math -import re -import six - -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.ops import gen_tpu_ops -from tensorflow.contrib.tpu.proto import tpu_embedding_configuration_pb2 as elc -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import partitioned_variables -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables - -TRAINING = elc.TPUEmbeddingConfiguration.TRAINING -INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE - -# TODO(shizhiw): A better interface is to make `num_hosts` and -# `num_cores_per_host` optional parameters for `TPUEmbedding` -# constructor. Usually they can be automatically detected, but -# user can also specify them for debugging (b/112112496). -# Auto-detection can be done with `tpu_system_metadata.py`. 
-_MASTER_JOB = 'tpu_worker' -_HOST_PATTERN = '/job:tpu_worker/task:{}/device:CPU:0' -_NUM_CORES_PER_HOST = 8 - -_TEST_MASTER_JOB = None -_TEST_HOST = '/replica:0/task:0/device:CPU:0' -_TEST_NUM_CORES_PER_HOST = 2 - - -class TableConfig( - collections.namedtuple( - 'TableConfig', - ['vocabulary_size', 'dimension', 'initializer', 'combiner'])): - """Embedding table configuration.""" - - @experimental - def __new__(cls, - vocabulary_size, - dimension, - initializer=None, - combiner='mean'): - """Embedding table configuration. - - Args: - vocabulary_size: Number of vocabulary (/rows) in the table. - dimension: The embedding dimension. - initializer: A variable initializer function to be used in embedding - variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean `0.0` and standard deviation - `1/sqrt(dimension)`. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with - 'mean' the default. 'sqrtn' often achieves good accuracy, in particular - with bag-of-words columns. For more information, see - `tf.nn.embedding_lookup_sparse`. - - Returns: - `TableConfig`. - - Raises: - ValueError: if `vocabulary_size` is not positive integer. - ValueError: if `dimension` is not positive integer. - ValueError: if `initializer` is specified and is not callable. - ValueError: if `combiner` is not supported. 
- """ - if not isinstance(vocabulary_size, int) or vocabulary_size < 1: - raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size)) - - if not isinstance(dimension, int) or dimension < 1: - raise ValueError('Invalid dimension {}.'.format(dimension)) - - if (initializer is not None) and (not callable(initializer)): - raise ValueError('initializer must be callable if specified.') - if initializer is None: - initializer = init_ops.truncated_normal_initializer( - mean=0.0, stddev=1 / math.sqrt(dimension)) - - if combiner not in ('mean', 'sum', 'sqrtn'): - raise ValueError('Invalid combiner {}'.format(combiner)) - - return super(TableConfig, cls).__new__(cls, vocabulary_size, dimension, - initializer, combiner) - - -# TODO(shizhiw): Factor `use_gradient_accumulation` and -# `pipeline_execution_with_tensor_core` out of `_OptimizationParameters`. -class _OptimizationParameters(object): - """Parameters common to all optimizations.""" - - def __init__(self, learning_rate, use_gradient_accumulation, - pipeline_execution_with_tensor_core): - self.learning_rate = learning_rate - self.use_gradient_accumulation = use_gradient_accumulation - self.pipeline_execution_with_tensor_core = ( - pipeline_execution_with_tensor_core) - - -class AdagradParameters(_OptimizationParameters): - """Optimization parameters for Adagrad.""" - - def __init__(self, learning_rate, initial_accumulator, - use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - """Optimization parameters for Adagrad. - - Args: - learning_rate: used for updating embedding table. - initial_accumulator: initial accumulator for Adagrad. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - for details. 
- pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. - """ - super(AdagradParameters, self).__init__(learning_rate, - use_gradient_accumulation, - pipeline_execution_with_tensor_core) - self.initial_accumulator = initial_accumulator - - -class AdamParameters(_OptimizationParameters): - """Optimization parameters for Adam.""" - - def __init__(self, learning_rate, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - lazy_adam=True, - sum_inside_sqrt=True, - use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - """Optimization parameters for Adam. - - Args: - learning_rate: a floating point value. The learning rate. - beta1: A float value. - The exponential decay rate for the 1st moment estimates. - beta2: A float value. - The exponential decay rate for the 2nd moment estimates. - epsilon: A small constant for numerical stability. - lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. - Please see `optimization_parameters.proto` for details. - sum_inside_sqrt: This improves training speed. Please see - `optimization_parameters.proto` for details. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - for details. - pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. 
- """ - super(AdamParameters, self).__init__(learning_rate, - use_gradient_accumulation, - pipeline_execution_with_tensor_core) - self.beta1 = beta1 - self.beta2 = beta2 - self.epsilon = epsilon - self.lazy_adam = lazy_adam - self.sum_inside_sqrt = sum_inside_sqrt - - -class StochasticGradientDescentParameters(_OptimizationParameters): - """Optimization parameters for stochastic gradient descent. - - Args: - learning_rate: a floating point value. The learning rate. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. - """ - - def __init__(self, learning_rate, use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - super(StochasticGradientDescentParameters, self).__init__( - learning_rate, use_gradient_accumulation, - pipeline_execution_with_tensor_core) - - -class TPUEmbedding(object): - """API for using TPU for embedding. - - Example: - ``` - table_config_user = tpu_embedding.TableConfig( - vocabulary_size=4, dimension=2, - initializer=initializer, combiner='mean') - table_to_config_dict = {'video': table_config_video, - 'user': table_config_user} - feature_to_table_dict = {'watched': 'video', - 'favorited': 'video', - 'friends': 'user'} - batch_size = 4 - num_hosts = 1 - optimization_parameters = tpu_embedding.AdagradParameters(1., 1.) 
- mode = tpu_embedding.TRAINING - embedding = tpu_embedding.TPUEmbedding( - table_to_config_dict, feature_to_table_dict, - batch_size, num_hosts, mode, optimization_parameters) - - batch_size_per_core = embedding.batch_size_per_core - sparse_features_list = [] - for host in hosts: - with ops.device(host): - for _ in range(embedding.num_cores_per_host): - sparse_features = {} - sparse_features['watched'] = sparse_tensor.SparseTensor(...) - sparse_features['favorited'] = sparse_tensor.SparseTensor(...) - sparse_features['friends'] = sparse_tensor.SparseTensor(...) - sparse_features_list.append(sparse_features) - - enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list) - - def computation(): - activations = embedding.get_activations() - loss = compute_loss(activations) - - base_optimizer = gradient_descent.GradientDescentOptimizer( - learning_rate=1) - cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer( - base_optimizer) - - train_op = cross_shard_optimizer.minimize(loss) - # `train_op` and `send_gradients_op` must happen in order. - with ops.control_dependencies([train_op]): - send_gradients_op = embedding.generate_send_gradients_op() - with ops.control_dependencies([send_gradients_op]): - loss = array_ops.identity(loss) - - loss = tpu.shard(computation, - num_shards=embedding.num_cores) - - with self.test_session() as sess: - sess.run(tpu.initialize_system(embedding_config= - embedding.config_proto)) - sess.run(variables.global_variables_initializer()) - sess.run(embedding.init_ops) - sess.run(enqueue_ops) - loss_val = sess.run(loss) - ``` - """ - - # TODO(shizhiw): Instead of `feature_to_table_dict` which maps to table - # name, consider `feature_to_config_dict` which maps to `FeatureConfig`. - # `FeatureConfig` could have fields other than table name. For example, it - # could have a field to indicate that the feature should not be used to - # update embedding table (cr/204852758, cr/204940540). 
Also, this can support - # different combiners for different features within the same table. - # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it - # to `FeatureConfig`? - - # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and - # `feature_to_table_dict` lists of `TableSpec` and `FeatureSpec` respectively? - - # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate - # for-loops around construction of inputs. - - # `optimization_parameter` applies to all tables. If the need arises, - # we can add `optimization_parameters` to `TableConfig` to override this - # global setting. - @experimental - def __init__(self, - table_to_config_dict, - feature_to_table_dict, - batch_size, - num_hosts, - mode, - optimization_parameters=None, - tpu_embedding_test=False): - """API for using TPU for embedding lookups. - - Args: - table_to_config_dict: A dictionary mapping from string of table name to - `TableConfig`. Table refers to an embedding table, e.g. `params` - argument to `tf.nn.embedding_lookup_sparse()`. - feature_to_table_dict: A dictionary mapping from string of feature name - to string of table name. Feature refers to ids to lookup in embedding - table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`. - batch_size: An `int` representing the global batch size. - num_hosts: An `int` representing the number of TPU hosts. - mode: `TRAINING` or `INFERENCE`. - optimization_parameters: `AdagradParameters`, `AdamParameters`, - `Stochasticgradientdescentparameters`. Must be set in training and must - be `None` in inference. - tpu_embedding_test: A `bool`. Only used for testing. - - Raises: - ValueError: if any input is invalid. - """ - _validate_table_to_config_dict(table_to_config_dict) - # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`. 
- self._table_to_config_dict = _create_ordered_dict(table_to_config_dict) - self._combiners = _create_combiners(self._table_to_config_dict) - - _validate_feature_to_table_dict(table_to_config_dict, feature_to_table_dict) - self._feature_to_table_dict = _create_ordered_dict(feature_to_table_dict) - self._table_to_features_dict = _create_table_to_features_dict( - self._feature_to_table_dict) - - self._batch_size = batch_size - - if tpu_embedding_test: - self._num_hosts = 1 - self._hosts = [_TEST_HOST] - self._num_cores_per_host = _TEST_NUM_CORES_PER_HOST - else: - self._num_hosts = num_hosts - self._hosts = [_HOST_PATTERN.format(i) for i in range(self._num_hosts)] - self._num_cores_per_host = _NUM_CORES_PER_HOST - self._num_cores = self._num_cores_per_host * self._num_hosts - - _validate_batch_size(self._batch_size, self._num_cores) - self._batch_size_per_core = self._batch_size // self._num_cores - - self._init_ops = [] - - # TODO(shizhiw): remove `mode`? - if mode == TRAINING: - _validate_optimization_parameters(optimization_parameters) - self._optimization_parameters = optimization_parameters - elif mode == INFERENCE: - if optimization_parameters is not None: - raise ValueError('`optimization_parameters` should be `None` ' - 'for inference mode.') - self._optimization_parameters = ( - StochasticGradientDescentParameters(1.)) - else: - raise ValueError('`mode` only supports {} and {}; got {}.' - .format(TRAINING, INFERENCE, mode)) - self._mode = mode - - # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler` - # and create special handler for inference that inherits from - # StochasticGradientDescentHandler with more user-friendly error message - # on get_slot(). 
- self._optimizer_handler = _get_optimization_handler( - self._optimization_parameters) - - dummy_table_variables_init_op = self._create_dummy_table_variables() - self._init_ops.append(dummy_table_variables_init_op) - - self._config_proto = self._create_config_proto() - - self._create_variables_and_ops() - self._init_ops.extend(self._load_parameters_ops) - - @property - def hosts(self): - """A list of device names for CPU hosts. - - Returns: - A list of device names for CPU hosts. - """ - return self._hosts - - # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and - # to be consistent with `tpu_embedding_configuration.proto`. - @property - def num_cores_per_host(self): - """Number of TPU cores on a CPU host. - - Returns: - Number of TPU cores on a CPU host. - """ - return self._num_cores_per_host - - @property - def num_cores(self): - """Total number of TPU cores on all hosts. - - Returns: - Total number of TPU cores on all hosts. - """ - return self._num_cores - - @property - def batch_size_per_core(self): - """Batch size for each TPU core. - - The sparse tensors in `sparse_features_list` to `generate_enqueue_ops` - must have batch dimension equal to this. - - Returns: - Batch size for each TPU core. - """ - return self._batch_size_per_core - - @property - def config_proto(self): - """Create embedding config proto for `tpu.initialize_system()`. - - Returns: - an `TPUEmbeddingConfiguration` proto describing the desired - configuration of the hardware embedding lookup tables, which - is passed to `tpu.initialize_system()`. - """ - return self._config_proto - - @property - def init_ops(self): - """Initialization ops for TPU embedding. - - It must be called after all global variables have been initialized, - i.e. after `global_variables_initializer()`, as it loads embedding - tables into TPU. - - Returns: - A list of ops. - """ - return self._init_ops - - # TODO(shizhiw): get table variables the same way as getting slot variables. 
- @property - def table_to_table_variables_dict(self): - return copy.copy(self._table_to_table_variables_dict) - - def get_slot_names(self): - """Return a list of the names of slots created by `TPUEmbedding`.""" - return self._optimizer_handler.get_slot_names() - - def get_slot(self, table, name): - """Return a slot named `name` create for `table` by `TPUEmbedding`.""" - return self._optimizer_handler.get_slot(table, name) - - # TODO(shizhiw): expose load to user too? - @property - def retrieve_parameters_ops(self): - return self._retrieve_parameters_ops - - def _create_config_proto(self): - """Create `TPUEmbeddingConfiguration`.""" - config_proto = elc.TPUEmbeddingConfiguration() - for table in self._table_to_config_dict: - table_descriptor = config_proto.table_descriptor.add() - table_descriptor.name = table - - table_config = self._table_to_config_dict[table] - table_descriptor.vocabulary_size = table_config.vocabulary_size - table_descriptor.dimension = table_config.dimension - - features_for_table = self._table_to_features_dict[table] - table_descriptor.num_features = len(features_for_table) - - table_descriptor.optimization_parameters.learning_rate.constant = ( - self._optimization_parameters.learning_rate) - table_descriptor.optimization_parameters.use_gradient_accumulation = ( - self._optimization_parameters.use_gradient_accumulation) - self._optimizer_handler.set_optimization_parameters(table_descriptor) - - config_proto.mode = self._mode - config_proto.batch_size_per_tensor_core = self._batch_size_per_core - config_proto.num_hosts = self._num_hosts - config_proto.num_tensor_cores = self._num_cores - config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT - config_proto.pipeline_execution_with_tensor_core = ( - self._optimization_parameters.pipeline_execution_with_tensor_core) - - return config_proto - - def _create_variables_and_ops(self): - """Create embedding variables and return ops to load them into TPU.""" - 
self._load_parameters_ops = [] - self._retrieve_parameters_ops = [] - self._table_to_table_variables_dict = {} - for table in self._table_to_config_dict: - device_fn = _create_device_fn(self._hosts) - with ops.device(device_fn): - # TODO(shizhiw): allow user to specify variable name so that - # they could make the name consistent with CPU etc. - variable_name = table - table_variables = _create_partitioned_variables( - name=variable_name, - num_hosts=self._num_hosts, - vocabulary_size=self._table_to_config_dict[table].vocabulary_size, - embedding_dimension=self._table_to_config_dict[table].dimension, - initializer=self._table_to_config_dict[table].initializer, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - self._table_to_table_variables_dict[table] = table_variables - - self._optimizer_handler.create_variables_and_ops( - table, variable_name, self._num_hosts, - self._table_to_config_dict[table], table_variables, - self._load_parameters_ops, self._retrieve_parameters_ops) - - def _create_dummy_table_variables(self): - """Create dummy embedding table variables. - - The sole purpose of these dummy variables are to trigger gradient - calcuation wrt them so that the gradients wrt activation can be captured - and later sent to TPU embedding. - - Returns: - Initializer for these variables. - - Raises: - RuntimeError: if collection to store gradients already exists and is not - empty. - """ - self._dummy_table_variables = [] - # TODO(shizhiw): remove table id. - for table_id, table in enumerate(self._table_to_features_dict): - self._dummy_table_variables.append( - variable_scope.get_variable( - 'tpu_embedding_dummy_table_variable_%s' % table, - dtype=dtypes.float32, - shape=[1], - use_resource=True, - trainable=True, - # TODO(shizhiw): Remove these dummy variables as - # tensorflow optimizer creates slot variable for them which - # is undesirable. - # e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1}. 
- # Explicitly specifying collections prevents this variable from - # being added to the GLOBAL_VARIABLES collection, so that Saver() - # ignores it. - collections=['tpu_embedding_dummy_table_variables'])) - - g = ops.get_default_graph() - table_gradients = g.get_collection_ref( - 'tpu_embedding_gradients_table_%d' % table_id) - if table_gradients: - raise RuntimeError( - 'tpu_embedding_gradients_table_%d is not empty.' % table_id) - table_gradients.extend([None] * len(self._table_to_features_dict[table])) - - return variables.variables_initializer( - self._dummy_table_variables, - name='tpu_embedding_dummy_table_variables_init') - - def generate_enqueue_ops(self, sparse_features_list): - """Generate enqueue ops. - - Args: - sparse_features_list: a list of dictionary mapping from string - of feature names to sparse tensor. Each dictionary is for one - TPU core. Dictionaries for the same core should be contiguous - on the list. - - Returns: - Ops to enqueue to TPU for embedding. - """ - self._validate_generate_enqueue_ops_sparse_features_list( - sparse_features_list) - return [ - self._generate_enqueue_op( - sparse_features, device_ordinal=i % self._num_cores_per_host) - for i, sparse_features in enumerate(sparse_features_list) - ] - - def _validate_generate_enqueue_ops_sparse_features_list( - self, sparse_features_list): - """Validate `sparse_features_list`.""" - if len(sparse_features_list) != self._num_cores: - raise ValueError('Length of `sparse_features_list` should match the ' - 'number of cores; ' - '`len(sparse_features_list)` is {}, ' - 'number of cores is {}.'.format( - len(sparse_features_list), self._num_cores)) - - feature_set = set(self._feature_to_table_dict.keys()) - contiguous_device = None - for i, sparse_features in enumerate(sparse_features_list): - used_feature_set = set(sparse_features.keys()) - - # Check features are valid. 
- missing_feature_set = feature_set - used_feature_set - if missing_feature_set: - raise ValueError('`sparse_features_list[{}]` misses a feature that is ' - 'in `feature_to_config_dict`: {}.'.format( - i, missing_feature_set)) - - extra_feature_set = used_feature_set - feature_set - if extra_feature_set: - raise ValueError('`sparse_features_list[{}]` has a feature that is not ' - 'in `feature_to_config_dict`: {}.'.format( - i, extra_feature_set)) - - device = None - device_feature = None - for feature, tensor in six.iteritems(sparse_features): - if not isinstance(tensor, sparse_tensor.SparseTensor): - raise ValueError('`sparse_features_list[{}]` has a feature that is ' - 'not mapped to `SparseTensor`. ' - '`feature`: {}, type: {}'.format( - i, feature, type(tensor))) - - # Check all features are on the same device. - if device is None: - device = tensor.op.device - device_feature = feature - else: - if device != tensor.op.device: - raise ValueError('Devices are different between features in ' - '`sparse_features_list[{}]`; ' - 'devices: {}, {}; features: {}, {}.'.format( - i, device, tensor.op.device, feature, - device_feature)) - - if i % self._num_cores_per_host: - if device != contiguous_device: - raise ValueError('We expect the `sparse_features` which are on the ' - 'same host to be contiguous in ' - '`sparse_features_list`, ' - '`sparse_features_list[{}]` is on device {}, ' - 'but is expected to be on device {}.'.format( - i, device, contiguous_device)) - else: - contiguous_device = device - - def _generate_enqueue_op(self, sparse_features, device_ordinal): - with ops.colocate_with(list(sparse_features.values())[0]): - sample_idcs, embedding_idcs, aggregation_weights = ( - self._format_for_tpu_embedding_sparse_batch(sparse_features)) - return tpu_ops.enqueue_tpu_embedding_sparse_batch( - sample_idcs, - embedding_idcs, - aggregation_weights, - combiners=self._combiners, - device_ordinal=device_ordinal) - - def _format_for_tpu_embedding_sparse_batch(self, 
sparse_features): - """Format sparse features for `enqueue_tpu_embedding_sparse_batch()`. - - Args: - sparse_features: a `Dict` of `SparseTensor`s for embedding. - - Returns: - Arguments for `enqueue_tpu_embedding_sparse_batch()`. - """ - - sample_idcs, embedding_idcs, aggregation_weights = list(), list(), list() - for table in self._table_to_features_dict: - sample_t, indices_t, weights_t = list(), list(), list() - - features = self._table_to_features_dict[table] - for i, feature in enumerate(features): - tensor = sparse_features[feature] - sample_indices = tensor.indices[:, 0] - embedding_indices = tensor.values - weights = array_ops.ones_like(embedding_indices) - sample_t.append(i * self._batch_size_per_core + sample_indices) - indices_t.append(embedding_indices) - weights_t.append(weights) - - sample_idcs.append( - math_ops.cast(array_ops.concat(sample_t, axis=0), dtype=dtypes.int32)) - embedding_idcs.append( - math_ops.cast( - array_ops.concat(indices_t, axis=0), dtype=dtypes.int32)) - aggregation_weights.append( - math_ops.cast( - array_ops.concat(weights_t, axis=0), dtype=dtypes.float32)) - - return sample_idcs, embedding_idcs, aggregation_weights - - def get_activations(self): - """Get activations for features. - - This should be called within `computation` that is passed to - `tpu.replicate` and friends. - - Returns: - A dictionary mapping from `String` of feature name to `Tensor` - of activation. 
- """ - recv_activations = tpu_ops.recv_tpu_embedding_activations( - num_outputs=len(self._table_to_config_dict), - config=self._config_proto.SerializeToString()) - - activations = collections.OrderedDict() - for table_id, table in enumerate(self._table_to_features_dict): - features = self._table_to_features_dict[table] - for lookup_id, feature in enumerate(features): - start_row = lookup_id * self._batch_size_per_core - end_row = start_row + self._batch_size_per_core - activations[feature] = gen_tpu_ops.tpu_embedding_activations( - self._dummy_table_variables[table_id], - recv_activations[table_id][start_row:end_row, :], - table_id=table_id, - lookup_id=lookup_id) - return activations - - # TODO(shizhiw): Make `gradient_multiplier` per feature. Setting it to 0 would - # have the effect of `tf.stop_gradients()`. - # TODO(shizhiw): Consider alternative ways to capture gradients wrt embedding - # layer outputs to remove `_dummy_table_variables`, - # `_embedding_activation_grad` and `tpu_embedding_gradients_table_%d'. - def generate_send_gradients_op(self, gradient_multipliers=None): - """Retrieve gradients from collections and send them to TPU embedding. - - Args: - gradient_multipliers: None, or dict mapping table names to gradient - multiplier Tensors. - - Returns: - SendTPUEmbeddingGradients Op. - - Raises: - ValueError: If required gradients have not been defined. - RuntimeError: If `mode` is not `TRAINING`. - """ - if self._mode != TRAINING: - raise RuntimeError('Only in training mode gradients need to ' - 'be sent to TPU embedding; got mode {}.' 
- .format(self._mode)) - - g = ops.get_default_graph() - gradients = list() - for table_id, table in enumerate(self._table_to_config_dict): - table_gradients = g.get_collection( - 'tpu_embedding_gradients_table_%d' % table_id) - if any(gradient is None for gradient in table_gradients): - raise ValueError( - 'Table {}/{} has undefined gradients: this is probably because the ' - 'model asked TPUEmbedding to compute activations that were not ' - 'used.'.format(table_id, table)) - concat_table_grads = array_ops.concat(table_gradients, axis=0) - if gradient_multipliers is not None: - concat_table_grads *= gradient_multipliers[table.name] - gradients.append(concat_table_grads) - - return tpu_ops.send_tpu_embedding_gradients( - inputs=gradients, config=self.config_proto.SerializeToString()) - - -def _validate_table_to_config_dict(table_to_config_dict): - """Validate `table_to_config_dict`.""" - for k, v in six.iteritems(table_to_config_dict): - if not isinstance(v, TableConfig): - raise ValueError('Value of `table_to_config_dict` must be of type ' - '`TableConfig`, got {} for {}.'.format(type(v), k)) - - -def _validate_feature_to_table_dict(table_to_config_dict, - feature_to_table_dict): - """Validate `feature_to_table_dict`.""" - used_table_set = set(feature_to_table_dict.values()) - table_set = set(table_to_config_dict.keys()) - - unused_table_set = table_set - used_table_set - if unused_table_set: - raise ValueError('`table_to_config_dict` specifies table that is not ' - 'used in `feature_to_table_dict`: {}.' - .format(unused_table_set)) - - extra_table_set = used_table_set - table_set - if extra_table_set: - raise ValueError('`feature_to_table_dict` refers to a table that is not ' - 'specified in `table_to_config_dict`: {}.' - .format(extra_table_set)) - - -def _validate_batch_size(batch_size, num_cores): - if batch_size % num_cores: - raise ValueError('`batch_size` is not a multiple of number of ' - 'cores. 
`batch_size`={}, `_num_cores`={}.'.format( - batch_size, num_cores)) - - -def _validate_optimization_parameters(optimization_parameters): - if not isinstance(optimization_parameters, _OptimizationParameters): - raise ValueError('`optimization_parameters` must inherit from ' - '`_OptimizationPramaters`. ' - '`type(optimization_parameters)`={}'.format( - type(optimization_parameters))) - - -class _OptimizerHandler(object): - """Interface class for handling optimizer specific logic.""" - - def __init__(self, optimization_parameters): - self._optimization_parameters = optimization_parameters - - def set_optimization_parameters(self, table_descriptor): - raise NotImplementedError() - - def create_variables_and_ops(self, table, variable_name): - raise NotImplementedError() - - def get_slot_names(self): - raise NotImplementedError() - - def get_slot(self, table, name): - raise NotImplementedError() - - -class _AdagradHandler(_OptimizerHandler): - """Handles Adagrad specific logic.""" - - def __init__(self, optimization_parameters): - super(_AdagradHandler, self).__init__(optimization_parameters) - self._table_to_accumulator_variables_dict = {} - - def set_optimization_parameters(self, table_descriptor): - table_descriptor.optimization_parameters.adagrad.SetInParent() - - def create_variables_and_ops(self, table, variable_name, num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): - optimizer_name = 'Adagrad' - accumulator_initializer = init_ops.constant_initializer( - self._optimization_parameters.initial_accumulator) - accumulator_variables = _create_partitioned_variables( - name='%s/%s' % (variable_name, optimizer_name), - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=accumulator_initializer) - - self._table_to_accumulator_variables_dict[table] = accumulator_variables - for host_id, table_variable, 
accumulator_variable in (zip( - range(num_hosts), table_variables, accumulator_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops.load_tpu_embedding_adagrad_parameters( - parameters=table_variable, - accumulators=accumulator_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table, retrieved_accumulator = ( - tpu_ops.retrieve_tpu_embedding_adagrad_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table), - state_ops.assign(accumulator_variable, retrieved_accumulator)) - - load_parameters_ops.append(load_parameters_op) - retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return ['accumulator'] - - def get_slot(self, table, name): - if name not in self.get_slot_names(): - raise ValueError('Adagrad has {} as slot names; got {}.' - .format(self.get_slot_names(), name)) - return self._table_to_accumulator_variables_dict[table] - - -class _AdamHandler(_OptimizerHandler): - """Handles Adam specific logic.""" - - def __init__(self, optimization_parameters): - super(_AdamHandler, self).__init__(optimization_parameters) - self._table_to_m_variables_dict = {} - self._table_to_v_variables_dict = {} - - def set_optimization_parameters(self, table_descriptor): - table_descriptor.optimization_parameters.adam.beta1 = ( - self._optimization_parameters.beta1) - table_descriptor.optimization_parameters.adam.beta2 = ( - self._optimization_parameters.beta2) - table_descriptor.optimization_parameters.adam.epsilon = ( - self._optimization_parameters.epsilon) - table_descriptor.optimization_parameters.adam.use_non_lazy_adam = ( - not self._optimization_parameters.lazy_adam) - table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = ( - self._optimization_parameters.sum_inside_sqrt) - - def create_variables_and_ops(self, table, variable_name, 
num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): - optimizer_name = 'Adam' - m_initializer = init_ops.zeros_initializer() - m_variables = _create_partitioned_variables( - name='%s/%s/m' % (variable_name, optimizer_name), - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=m_initializer) - v_initializer = init_ops.zeros_initializer() - v_variables = _create_partitioned_variables( - name='%s/%s/v' % (variable_name, optimizer_name), - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=v_initializer) - - self._table_to_m_variables_dict[table] = m_variables - self._table_to_v_variables_dict[table] = v_variables - - for host_id, table_variable, m_variable, v_variable in (zip( - range(num_hosts), table_variables, - m_variables, v_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops.load_tpu_embedding_adam_parameters( - parameters=table_variable, - momenta=m_variable, - velocities=v_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table, retrieved_m, retrieved_v = ( - tpu_ops.retrieve_tpu_embedding_adam_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table), - state_ops.assign(m_variable, retrieved_m), - state_ops.assign(v_variable, retrieved_v)) - - load_parameters_ops.append(load_parameters_op) - retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return ['m', 'v'] - - def get_slot(self, table, name): - if name == 'm': - return self._table_to_m_variables_dict[table] - elif name == 'v': - return self._table_to_v_variables_dict[table] - else: - raise 
ValueError('Adam has {} as slot names; got {}.' - .format(self.get_slot_names(), name)) - - -class _StochasticGradientDescentHandler(_OptimizerHandler): - """Handles stochastic gradient descent specific logic.""" - - def set_optimization_parameters(self, table_descriptor): - (table_descriptor.optimization_parameters.stochastic_gradient_descent - .SetInParent()) - - def create_variables_and_ops(self, table, variable_name, num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): - del table_config - - for host_id, table_variable in (zip( - range(num_hosts), table_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops - .load_tpu_embedding_stochastic_gradient_descent_parameters( - parameters=table_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table = ( - tpu_ops - .retrieve_tpu_embedding_stochastic_gradient_descent_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table)) - - load_parameters_ops.append(load_parameters_op) - retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return [] - - def get_slot(self, table, name): - raise ValueError('Stochastic gradient descent does not have slot variable.') - - -def _get_optimization_handler(optimization_parameters): - if isinstance(optimization_parameters, AdagradParameters): - return _AdagradHandler(optimization_parameters) - elif isinstance(optimization_parameters, AdamParameters): - return _AdamHandler(optimization_parameters) - elif isinstance(optimization_parameters, StochasticGradientDescentParameters): - return _StochasticGradientDescentHandler(optimization_parameters) - else: - return NotImplementedError() - - -def _create_ordered_dict(d): - """Create an OrderedDict from Dict.""" - return collections.OrderedDict((k, d[k]) for k in sorted(d)) - - -def 
_create_combiners(table_to_config_dict): - return [table_to_config_dict[t].combiner for t in table_to_config_dict] - - -def _create_table_to_features_dict(feature_to_table_dict): - """Create mapping from table to a list of its features.""" - table_to_features_dict_tmp = {} - for feature, table in six.iteritems(feature_to_table_dict): - if table in table_to_features_dict_tmp: - table_to_features_dict_tmp[table].append(feature) - else: - table_to_features_dict_tmp[table] = [feature] - - table_to_features_dict = collections.OrderedDict() - for table in sorted(table_to_features_dict_tmp): - table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table]) - return table_to_features_dict - - -def _create_device_fn(hosts): - """Create device_fn() to use with _create_partitioned_variables().""" - - def device_fn(op): - """Returns the `device` for `op`.""" - part_match = re.match(r'.*/part_(\d+)(/|$)', op.name) - - if part_match: - idx = int(part_match.group(1)) - else: - raise RuntimeError('Internal Error: ' - 'Expected %s to contain /part_*.' % op.name) - - device = hosts[idx] - return device - - return device_fn - - -def _create_partitioned_variables(name, - num_hosts, - vocabulary_size, - embedding_dimension, - initializer, - collections=None): # pylint: disable=redefined-outer-name - """Creates ParitionedVariables based on `num_hosts` for `table`.""" - # TODO(shizhiw): automatically place embedding lookup elsewhere? - if vocabulary_size < num_hosts: - raise ValueError('`vocabulary_size`({}) is smaller than `num_hosts`({}). 
' - 'As TPU embedding is not optimized for small tables, ' - 'please consider other ways for this embedding lookup.') - - return list(variable_scope.get_variable( - name, - shape=(vocabulary_size, embedding_dimension), - partitioner=partitioned_variables.fixed_size_partitioner(num_hosts), - dtype=dtypes.float32, - initializer=initializer, - collections=collections, - trainable=False)) - - -@ops.RegisterGradient('TPUEmbeddingActivations') -def _embedding_activations_grad(activations_op, grad_wrt_activations): - """Saves the gradient of embedding activations ops in a graph collection.""" - g = ops.get_default_graph() - table_id = activations_op.get_attr('table_id') - lookup_id = activations_op.get_attr('lookup_id') - table_gradients = g.get_collection_ref( - 'tpu_embedding_gradients_table_%d' % table_id) - - if not table_gradients: - raise RuntimeError( - 'Gradients for TPUEmbedding have been generated in non-training mode. ' - 'This is not expected. Consider putting your Optimizer.minimize code ' - 'behind the training mode condition check. For Estimator, you can ' - 'do \n\n' - ' if mode == tf.estimator.ModeKeys.TRAIN:\n' - ' train_op = opt.minimize(loss)\n' - '\n') - - table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations) - return [ - # RegisterGradient requires that value be returned for all inputs. Since - # the first argument (tpu_gradient_variable_{table_name}) has shape [1], - # we will return zeros(shape=[1]). The actual gradient w.r.t. the - # embedding activations (grad_wrt_activations) has the same shape as the - # activations returned by embedding_activations. 
- array_ops.zeros(arg.shape, dtype=dtypes.float32) - for arg in activations_op.inputs - ] +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_embedding import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..308adc77e9ad2d912d0461512655b55faa53da60 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py @@ -0,0 +1,23 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_embedding_gradient import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 96b9556e137effcaaa5916b9723142f737a6dc33..893118412e1363ce50416e6ef36692bc23d04179 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1,3468 +1,33 @@ -# Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== -"""TPUEstimator class.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import copy -import os -import signal -import sys -import threading -import time - -import numpy as np -import six -from six.moves import queue as Queue # pylint: disable=redefined-builtin -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.python.tpu import tensor_tracer -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import error_handling -from tensorflow.contrib.tpu.python.tpu import session_support -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_config -from tensorflow.contrib.tpu.python.tpu import tpu_context -from tensorflow.contrib.tpu.python.tpu import tpu_feed -from tensorflow.contrib.tpu.python.tpu import training_loop -from tensorflow.contrib.tpu.python.tpu import util as util_lib -from tensorflow.contrib.training.python.training import hparam -from tensorflow.core.framework import variable_pb2 
-from tensorflow.core.framework.summary_pb2 import Summary -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as tf_session -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest as data_nest -from tensorflow.python.estimator import estimator as estimator_lib -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export_output as export_output_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as contrib_summary -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.saved_model import tag_constants -from tensorflow.python.summary import summary -from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import evaluation -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training -from tensorflow.python.training import training_util -from tensorflow.python.util import function_utils -from tensorflow.python.util import nest -from tensorflow.python.util import tf_inspect - -_INITIAL_LOSS = 1e7 -_ZERO_LOSS = 0. 
-_TPU_ESTIMATOR = 'tpu_estimator' -_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' -_BATCH_SIZE_KEY = 'batch_size' -_CTX_KEY = 'context' -_USE_TPU_KEY = 'use_tpu' -_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' -_ONE_GIGABYTE = 1024 * 1024 * 1024 -_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' -_TPU_TRAIN_OP = '_tpu_train_op' -_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' - -# Ideally _USE_TPU_KEY should be reserved as well. However there are already -# models that make use of this key, thus it can not be reserved now to prevent -# breakage. In the long run, we would like to mitigate this by migrating models -# off of using _USE_TPU_KEY. -_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] - -# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is -# only used for per-core based deployments. For per-host based pipelines, if a -# user returns a Dataset instance it will be automatically wrapped in a -# tf.while_loop (This can be disabled by returning features and labels -# explicitly). -_WRAP_INPUT_FN_INTO_WHILE_LOOP = False - -ops.register_proto_function( - '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR), - proto_type=variable_pb2.VariableDef, - to_proto=resource_variable_ops._to_proto_fn, # pylint: disable=protected-access - from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access - - -def _is_iterable(obj): - """A Python 2 and 3 compatible util to check whether `obj` is iterable.""" - try: - iter(obj) - return True - except TypeError: - return False - - -def _create_global_step(graph): - graph = graph or ops.get_default_graph() - if training.get_global_step(graph) is not None: - raise ValueError('"global_step" already exists.') - # Create in proper graph and base name_scope. 
- with graph.as_default() as g, g.name_scope(None): - return variable_scope.get_variable( - ops.GraphKeys.GLOBAL_STEP, - shape=[], - dtype=dtypes.int64, - initializer=init_ops.zeros_initializer(), - trainable=False, - use_resource=True, - collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) - - -def _create_or_get_iterations_per_loop(): - """Creates or gets the iterations_per_loop variable. - - In TPUEstimator, the user provided computation, the model_fn, is wrapped - inside a tf.while_loop for peak performance. The iterations of the loop are - specified by this variable, which adjusts its value on the CPU after each TPU - program execution and before the next TPU execution. - - The purpose of using a variable, rather then a constant, is to allow - TPUEstimator adapt the TPU training iterations according to the final steps - specified by users. For example, if the user sets the iterations_per_loop as 4 - in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop - variable will have the following value before each TPU training. - - - 1-th TPU execution: iterations_per_loop = 4 - - 2-th TPU execution: iterations_per_loop = 4 - - 3-th TPU execution: iterations_per_loop = 2 - - As model_fn increases the global step once per train_op invocation, the global - step is 10 after all TPU executions, matching the steps=10 inputs passed in by - users. - - Returns: - A TF non-trainable resource variable. - - Raises: - RuntimeError: If multi iterations_per_loop variables were found. 
- """ - graph = ops.get_default_graph() - collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) - iter_vars = graph.get_collection(collection_name) - if len(iter_vars) == 1: - return iter_vars[0] - elif len(iter_vars) > 1: - raise RuntimeError('Multiple iterations_per_loop_var in collection.') - - with ops.colocate_with(training_util.get_global_step()): - with variable_scope.variable_scope( - _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE): - return variable_scope.get_variable( - _ITERATIONS_PER_LOOP_VAR, - initializer=init_ops.zeros_initializer(), - shape=[], - dtype=dtypes.int32, - trainable=False, - collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES], - use_resource=True) - - -def _sync_variables_ops(ctx): - """Create varriables synchronization ops. - - Gets the variables back from TPU nodes. This means the variables updated - by TPU will now be *synced* to host memory. - In BROADCAST mode, we skip this sync since the variables are ususally too - big to transmit via RPC. - - Args: - ctx: A `_InternalTPUContext` instance with mode. - - Returns: - A list of sync ops. - """ - - if not ctx.is_input_broadcast_with_iterators(): - return [ - array_ops.check_numerics(v.read_value(), - 'Gradient for %s is NaN' % v.name).op - for v in variables.trainable_variables() - ] - else: - return [control_flow_ops.no_op()] - - -def _increase_eval_step_op(iterations_per_loop): - """Returns an op to increase the eval step for TPU evaluation. - - Args: - iterations_per_loop: Tensor. The number of eval steps running in TPU system - before returning to CPU host for each `Session.run`. - - Returns: - An operation - """ - eval_step = evaluation._get_or_create_eval_step() # pylint: disable=protected-access - # Estimator evaluate increases 1 by default. So, we increase the difference. 
- return state_ops.assign_add( - eval_step, - math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype), - use_locking=True) - - -def _extract_key_names(tensor_or_dict): - if isinstance(tensor_or_dict, dict): - return sorted(tensor_or_dict.keys()) - return [] - - -class _SIGNAL(object): - """Signal used to control the thread of infeed/outfeed. - - All preserved signals must be negative numbers. Positive numbers are used to - indicate the number of iterations for next training/evaluation loop. - """ - NEXT_BATCH = -1 - STOP = -2 - - -class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. - - See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and - `export_outputs`. - - For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where - `metric_fn` runs on CPU to generate metrics and `tensors` represents the - `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. - To be precise, TPU evaluation expects a slightly different signature from the - `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a - dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. - The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The - `tensors` usually specify the model logits, which are transferred back from - TPU system to CPU host. All tensors must have be batch-major, i.e., the batch - size is the first dimension. Once all tensors are available at CPU host from - all shards, they are concatenated (on CPU) and passed as positional arguments - to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is - a dict. `metric_fn` takes the `tensors` and returns a dict from metric string - name to the result of calling a metric function, namely a `(metric_tensor, - update_op)` tuple. 
See `TPUEstimator` for MNIST example how to specify the - `eval_metrics`. - - `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This - function should not capture any Tensors in `model_fn`. - - `host_call` is a tuple of a `function` and a list or dictionary of `tensors` - to pass to that function and returns a list of Tensors. `host_call` currently - works for train() and evaluate(). The Tensors returned by the function is - executed on the CPU on every step, so there is communication overhead when - sending tensors from TPU to CPU. To reduce the overhead, try reducing the - size of the tensors. The `tensors` are concatenated along their major (batch) - dimension, and so must be >= rank 1. The `host_call` is useful for writing - summaries with `tf.contrib.summary.create_file_writer`. - """ - - def __new__(cls, - mode, - predictions=None, - loss=None, - train_op=None, - eval_metrics=None, - export_outputs=None, - scaffold_fn=None, - host_call=None, - training_hooks=None, - evaluation_hooks=None, - prediction_hooks=None): - """Creates a validated `TPUEstimatorSpec` instance.""" - host_calls = {} - if eval_metrics is not None: - host_calls['eval_metrics'] = eval_metrics - if host_call is not None: - host_calls['host_call'] = host_call - _OutfeedHostCall.validate(host_calls) - - training_hooks = tuple(training_hooks or []) - evaluation_hooks = tuple(evaluation_hooks or []) - prediction_hooks = tuple(prediction_hooks or []) - - for hook in training_hooks + evaluation_hooks + prediction_hooks: - if not isinstance(hook, session_run_hook.SessionRunHook): - raise TypeError('All hooks must be SessionRunHook instances, given: {}' - .format(hook)) - - return super(TPUEstimatorSpec, cls).__new__( - cls, - mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op, - eval_metrics=eval_metrics, - export_outputs=export_outputs, - scaffold_fn=scaffold_fn, - host_call=host_call, - training_hooks=training_hooks, - evaluation_hooks=evaluation_hooks, - 
prediction_hooks=prediction_hooks) - - def as_estimator_spec(self): - """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" - host_calls = {} - if self.eval_metrics is not None: - host_calls['eval_metrics'] = self.eval_metrics - if self.host_call is not None: - host_calls['host_call'] = self.host_call - host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls) - eval_metric_ops = None - if self.eval_metrics is not None: - eval_metric_ops = host_call_ret['eval_metrics'] - hooks = None - if self.host_call is not None: - hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] - hooks = tuple(hooks or []) - scaffold = self.scaffold_fn() if self.scaffold_fn else None - return model_fn_lib.EstimatorSpec( - mode=self.mode, - predictions=self.predictions, - loss=self.loss, - train_op=self.train_op, - eval_metric_ops=eval_metric_ops, - export_outputs=self.export_outputs, - scaffold=scaffold, - training_hooks=self.training_hooks + hooks, - evaluation_hooks=self.evaluation_hooks + hooks, - prediction_hooks=self.prediction_hooks + hooks) - - -class _OpQueueContext(object): - """Manages work queue and thread for a infeed/outfeed thread.""" - - def __init__(self, name, target, args): - self._name = name - self._queue = Queue.Queue() - args = (self,) + args - self._thread = threading.Thread(name=name, target=target, args=args) - self._thread.daemon = True - self._thread.start() - - def stop(self): - self._queue.put(_SIGNAL.STOP) - - def send_next_batch_signal(self, iterations): - self._queue.put(iterations) - - def read_iteration_counts(self): - while True: - iterations = self._queue.get(block=True) - logging.debug('%s read iterations %s', self._name, iterations) - if iterations == _SIGNAL.STOP: - logging.info('%s received shutdown signal, stopping.', self._name) - return - yield iterations - - def join(self): - logging.info('Shutting down %s thread.', self._name) - self.stop() - self._thread.join() - - -class _OpSignalOnceQueueContext(_OpQueueContext): - 
"""Manages work queue and thread for a infeed/outfeed thread. - - This subclass only signals once. - """ - - def __init__(self, name, target, args): - super(_OpSignalOnceQueueContext, self).__init__(name, target, args) - self._has_signaled = False - - def send_next_batch_signal(self, iterations): - if not self._has_signaled: - self._queue.put(iterations) - self._has_signaled = True - - -class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): - """A Session hook setting up the TPU initialization, infeed, and outfeed. - - This hook does two major things: - 1. initialize and shutdown TPU system. - 2. launch and join the threads for infeed enqueue and (optional) outfeed - dequeue. - """ - - def __init__(self, - ctx, - enqueue_ops, - dequeue_ops, - run_infeed_loop_on_coordinator=True, - rendezvous=None, - master=None, - session_config=None): - self._master_job = ctx.master_job - self._enqueue_ops = enqueue_ops - self._dequeue_ops = dequeue_ops - self._rendezvous = rendezvous - self._master = master - self._session_config = session_config - self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator - self._initial_infeed_sleep_secs = ( - ctx.config.tpu_config.initial_infeed_sleep_secs) - - self._feed_error = None - self._finished = False - self._should_initialize_tpu = True - - def begin(self): - logging.info('TPU job name %s', self._master_job) - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_ops = [] - if self._should_initialize_tpu: - self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] - else: - self._finalize_ops = [] - - summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() - self._init_ops.extend(summary_writer_init_ops) - # Get all the writer resources from the initializer, so we know what to - # flush. 
- for op in summary_writer_init_ops: - self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - - def _run_infeed(self, queue_ctx, session): - logging.info('Starting infeed thread controller.') - if self._initial_infeed_sleep_secs: - logging.info('Infeed thread sleeping for %d seconds.', - self._initial_infeed_sleep_secs) - time.sleep(self._initial_infeed_sleep_secs) - logging.info('Infeed thread starting after sleep') - - with self._rendezvous.catch_errors(source='infeed', session=session): - if self._run_infeed_loop_on_coordinator: - for count, steps in enumerate(queue_ctx.read_iteration_counts()): - for i in xrange(steps): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) - session.run(self._enqueue_ops) - else: - for _ in queue_ctx.read_iteration_counts(): - session.run(self._enqueue_ops) - logging.info('Infeed thread finished, shutting down.') - - def _run_outfeed(self, queue_ctx, session): - logging.info('Starting outfeed thread controller.') - with self._rendezvous.catch_errors(source='outfeed', session=session): - for count, steps in enumerate(queue_ctx.read_iteration_counts()): - for i in xrange(steps): - logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) - session.run(self._dequeue_ops) - logging.info('Outfeed thread finished, shutting down.') - - def _create_infeed_controller(self, name, target, args): - return _OpQueueContext(name=name, target=target, args=args) - - def after_create_session(self, session, coord): - if self._should_initialize_tpu: - logging.info('Init TPU system') - start = time.time() - with ops.Graph().as_default(): - with tf_session.Session( - self._master, config=self._session_config) as sess: - sess.run(tpu.initialize_system(job=self._master_job)) - logging.info('Initialized TPU in %d seconds', time.time() - start) - - session.run(self._init_ops, - options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - - self._infeed_controller = self._create_infeed_controller( - 
name='InfeedController', target=self._run_infeed, args=(session,)) - - self._outfeed_controller = _OpQueueContext( - name='OutfeedController', target=self._run_outfeed, args=(session,)) - - # Enable the worker watchdog to terminate workers on coordinator exit. - watchdog_timeout = int(os.environ.get('TF_TPU_WATCHDOG_TIMEOUT', '0')) - if watchdog_timeout > 0: - session_support.start_worker_watchdog(session, - shutdown_timeout=watchdog_timeout) - - def before_run(self, run_context): - self._feed_error = None - - iterations = run_context.session.run(self._iterations_per_loop_var) - - logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) - self._infeed_controller.send_next_batch_signal(iterations) - - logging.info('Dequeue next (%d) batch(es) of data from outfeed.', - iterations) - self._outfeed_controller.send_next_batch_signal(iterations) - - def end(self, session): - self._finished = True - logging.info('Stop infeed thread controller') - self._infeed_controller.join() - self._rendezvous.record_done('infeed') - - logging.info('Stop output thread controller') - self._outfeed_controller.join() - self._rendezvous.record_done('outfeed') - - logging.info('Shutdown TPU system.') - session.run(self._finalize_ops) - - -class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - - def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None, - master=None, session_config=None): - super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( - ctx, - enqueue_ops, - dequeue_ops, - run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous, - master=master, - session_config=session_config) - - def _create_infeed_controller(self, name, target, args): - return _OpSignalOnceQueueContext(name=name, target=target, args=args) - - -class _TPUStopAtStepHook(session_run_hook.SessionRunHook): - """Hook that requests stop at a specified step. 
- - This hook is similar to the `session_run_hook._StopAfterNEvalsHook` with - following differences for TPU training: - - 1. This hook sets the variable for iterations_per_loop, which is used by - `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed. - As the hook execution order is not guaranteed, the variable update is - handled in `after_create_session` and `after_run` as - `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`. - - 2. For each training loop (session.run), the global step could be increased - multiple times on TPU. The global step tensor value will be explicitly read - again in `after_run` to ensure the latest value is retrieved to avoid race - condition. - """ - - def __init__(self, iterations, num_steps=None, last_step=None): - """Initializes a `StopAtStepHook`. - - Args: - iterations: The number of iterations to run optimizer per training loop. - num_steps: Number of steps to execute. - last_step: Step after which to stop. - - Raises: - ValueError: If one of the arguments is invalid. 
- """ - if num_steps is None and last_step is None: - raise ValueError('One of num_steps or last_step must be specified.') - if num_steps is not None and last_step is not None: - raise ValueError('Only one of num_steps or last_step can be specified.') - self._num_steps = num_steps - self._last_step = last_step - self._iterations = iterations - - def _next_iterations(self, global_step, last_step): - gap = last_step - global_step - return min(gap, self._iterations) - - def begin(self): - self._global_step_tensor = training_util.get_global_step() - if self._global_step_tensor is None: - raise RuntimeError('Global step should be created.') - - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - global_step = session.run(self._global_step_tensor) - if self._last_step is None: - self._last_step = global_step + self._num_steps - - iterations = self._next_iterations(global_step, self._last_step) - - self._iterations_per_loop_var.load(iterations, session=session) - - def after_run(self, run_context, run_values): - # Global step cannot be retrieved via SessionRunArgs and before_run due to - # race condition. - global_step = run_context.session.run(self._global_step_tensor) - if global_step >= self._last_step: - run_context.request_stop() - else: - iterations = self._next_iterations(global_step, self._last_step) - self._iterations_per_loop_var.load( - iterations, session=run_context.session) - - -class _SetEvalIterationsHook(session_run_hook.SessionRunHook): - """Hook that requests stop at a specified step.""" - - def __init__(self, num_steps): - """Initializes a `_SetEvalIterationsHook`. - - Args: - num_steps: Number of steps to execute. 
- """ - self._num_steps = num_steps - - def begin(self): - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - self._iterations_per_loop_var.load(self._num_steps, session=session) - - -class _StoppingPredictHook(session_run_hook.SessionRunHook): - """Hook that requests stop according to the stopping signal in prediction.""" - - def __init__(self, scalar_stopping_signal): - self._scalar_stopping_signal = scalar_stopping_signal - - def begin(self): - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - # This is not necessary as we do not run infeed enqueue and outfeed dequeue - # in side threads for prediction model. But it makes the - # TPUInfeedOutfeedSessionHook prints nice message. - self._iterations_per_loop_var.load(1, session=session) - - def before_run(self, run_context): - return session_run_hook.SessionRunArgs(self._scalar_stopping_signal) - - def after_run(self, run_context, run_values): - _ = run_context - scalar_stopping_signal = run_values.results - if _StopSignals.should_stop(scalar_stopping_signal): - # NOTE(xiejw): In prediction, stopping signals are inserted for each - # batch. And we append one more batch to signal the system it should stop. - # The data flow might look like - # - # batch 0: images, labels, stop = 0 (user provided) - # batch 1: images, labels, stop = 0 (user provided) - # ... - # batch 99: images, labels, stop = 0 (user provided) - # batch 100: images, labels, stop = 1 (TPUEstimator appended) - # - # where the final batch (id = 100) is appended by TPUEstimator, so we - # should drop it before returning the predictions to user. - # To achieve that, we throw the OutOfRangeError in after_run. 
Once - # Monitored Session sees this error in SessionRunHook.after_run, the - # "current" prediction, i.e., batch with id=100, will be discarded - # immediately - raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') - - -def generate_per_core_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, host_device, host_id): - """Generates infeed enqueue ops for per-core input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """A fn returns enqueue_ops.""" - num_cores_per_host = ctx.num_of_cores_per_host - per_host_sharded_inputs = [] - for core_ordinal in range(num_cores_per_host): - with ops.name_scope('ordinal_%d' % (core_ordinal)): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, - input_device=host_device, - invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - if inputs.is_dataset: - raise TypeError( - '`input_fn` returning `Dataset` is not yet supported in ' - 'per-Core input pipeline deployment yet. 
Please set ' - 'TPUConfig.per_host_input_for_training to True or return ' - '`features` and `labels` from `input_fn`') - features, labels = inputs.features_and_labels() - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels)) - per_host_sharded_inputs.append(flattened_inputs) - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - captured_infeed_queue.capture(infeed_queue) - - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) - return per_host_enqueue_ops - - return enqueue_ops_fn, captured_infeed_queue - - -def generate_per_host_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id): - """Generates infeed enqueue ops for per-host input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - - dataset_initializer = None - - with ops.device(device): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device, invocation_index=host_id) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - if not is_dataset: - raise TypeError( - 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' - '`features` and `labels`.') - if batch_axis is not None: - raise TypeError('For mode PREDICT, batch_axis is not supported yet.') - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True) - - if is_dataset: - dataset_initializer = inputs.dataset_initializer() - - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """A Fn returning the TPU infeed enqueue ops. 
- - By providing as a Fn, it can be invoked inside the tf.while_loop such that - the input pipeline for multiple iterations can be executed by one - Session.run call. - - Returns: - list of dict of ops. - """ - with ops.device(device): - num_of_replicas_per_host = ctx.num_of_replicas_per_host - # Convert user input to features and labels. If the user returns a - # dataset, it is initialized and the features and labels extracted via - # `dataset.iterator.get_next()` - features, labels = inputs.features_and_labels() - signals = inputs.signals() - - inputs_structure_recorder.validate_and_record_structure(features, labels) - unsharded_tensor_list = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - - infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_tensor_list], - tuple_shapes=[t.shape for t in unsharded_tensor_list], - shard_dimensions=batch_axis) - captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_number_of_shards(num_of_replicas_per_host) - per_host_enqueue_ops = ( - infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_tensor_list, - placement_function=lambda x: device, - tpu_ordinal_function=tpu_ordinal_function_impl)) - if signals is None: - return per_host_enqueue_ops - else: - return { - 'ops': per_host_enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -def generate_per_host_v2_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, device, host_id): - """Generates infeed enqueue ops for per-host input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - dataset_initializer = None - - with ops.device(device): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device, invocation_index=host_id) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if not is_dataset: - raise TypeError('`input_fn` must return a `Dataset` 
for the PER_HOST_V2 ' - 'input pipeline configuration.') - - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True, - num_invocations_per_step=ctx.num_of_replicas_per_host) - - dataset_initializer = inputs.dataset_initializer() - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """Generates the per_host enqueue ops.""" - control_deps = [] - per_host_sharded_inputs = [] - num_replicas_per_host = ctx.num_of_replicas_per_host - cached_signals = None - with ops.device(device): - if not inputs.is_dataset: - raise TypeError('`input_fn` must return a `Dataset` for this mode.') - for _ in range(num_replicas_per_host): - # Use control dependencies to ensure a deterministic ordering. - with ops.control_dependencies(control_deps): - features, labels = inputs.features_and_labels() # Calls get_next() - signals = inputs.signals() - - # All the replicas share the replica 0's stopping singal. - # This avoids inconsistent state among different model replcias. 
- if cached_signals: - signals['stopping'] = cached_signals['stopping'] - else: - cached_signals = signals - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - control_deps.extend(flattened_inputs) - per_host_sharded_inputs.append(flattened_inputs) - - if inputs_structure_recorder.flattened_input_dims: - input_partition_dims = inputs_structure_recorder.flattened_input_dims - if signals: - input_partition_dims += [None] * len(signals) - # pylint: disable=protected-access - infeed_queue = tpu_feed._PartitionedInfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0]), - host_id=host_id, - input_partition_dims=input_partition_dims, - device_assignment=ctx.device_assignment) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs) - else: - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, - tpu_ordinal_function=tpu_ordinal_function_impl) - captured_infeed_queue.capture(infeed_queue) - - if signals is None: - return per_host_enqueue_ops - else: - return { - 'ops': per_host_enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, - num_hosts): - """Generates infeed enqueue ops for one input_fn on all the hosts.""" - captured_infeed_queue = _CapturedObject() - dataset_initializer = None - device_0 = ctx.tpu_host_placement_function(host_id=0) - with ops.device(device_0): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device_0, invocation_index=0) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - if not 
is_dataset: - raise TypeError( - 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' - '`features` and `labels`.') - - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True) - - if is_dataset: - dataset_initializer = inputs.dataset_initializer() - num_replicas_per_host = ctx.num_of_replicas_per_host - - def tpu_ordinal_function_impl(replica_id): - if ctx.device_assignment: - return ctx.device_assignment.tpu_ordinal(replica=replica_id) - else: - return replica_id % num_replicas_per_host - - def device_function_impl(replica_id): - return ctx.tpu_host_placement_function(replica_id=replica_id) - - def enqueue_ops_fn(): - """Generates enqueue ops for all the hosts.""" - broadcasted_inputs = [] - flattened_inputs = None # Cache result from input_fn. - signals = None - for host_id in xrange(num_hosts): - with ops.device(ctx.tpu_host_placement_function(host_id=host_id)): - for _ in xrange(ctx.num_of_replicas_per_host): - # Note: input_fn is only called once at host 0 for the first replica. - # The features and labels returned from that invocation are - # broadcasted to other replicas(including the replicas on other - # hosts). 
- if flattened_inputs is None: - features, labels = inputs.features_and_labels() # Calls get_next() - signals = inputs.signals() - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - broadcasted_inputs.append(flattened_inputs) - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(broadcasted_inputs[0])) - captured_infeed_queue.capture(infeed_queue) - enqueue_ops = infeed_queue.generate_enqueue_ops( - broadcasted_inputs, - tpu_ordinal_function=tpu_ordinal_function_impl, - placement_function=device_function_impl) - - if signals is None: - return enqueue_ops - else: - return { - 'ops': enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -class _InputPipeline(object): - """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. - - `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from - call site. To be precise, based on the configuration in - `_InternalTPUContext`, it invokes `input_fn` for all cores (usually - multi-host TPU training) or for one host (usually for single-host TPU - evaluation), and sends all `features` and `labels` returned by `input_fn` to - TPU infeed. For per-core invocation, `features` and `labels` are piped to - infeed directly, one tuple for each core. For per-host invocation, `features` - and `labels` are split at host (with respect to `batch_axis`) and piped to all - cores accordingly. - - In addition, flatten/unflatten are handled by `_InputPipeline` also. Model - inputs returned by the `input_fn` can have one of the following forms: - 1. features - 2. (features, labels) - 3. ((arbitrarily nested structure of features), labels) - - Internally, form 1 is reformed to `(features, None)` as features and labels - are passed separately to underlying methods. 
For TPU training, TPUEstimator - may expect multiple `features` and `labels` tuples one for each core. - - TPUEstimator allows various different structures for inputs (namely `features` - and `labels`). Both `features` and `labels` can be any nested sturcture - supported by TF nest (namely, dict, tuples, namedtuples or any nested - structure of such of Tensors). `labels` could be `None` as well. - - These are flattened before they are passed to the infeed/outfeed library - as that expectes flattend lists. - """ - - class InputsStructureRecorder(object): - """The recorder to record inputs structure.""" - - def __init__(self, input_partition_dims=None): - # Holds the structure of inputs - self._feature_structure = {} - self._flattened_input_dims = None - - if input_partition_dims: - # This should have been validated in TPUConfig. - assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.' - if len(input_partition_dims) == 2: - self._feature_dims, self._label_dims = input_partition_dims - else: - self._feature_dims = input_partition_dims[0] - self._label_dims = None - - assert self._feature_dims is not None, ('input_partition_dims[0] must ' - 'not be None') - else: - self._feature_dims = None - self._label_dims = None - - # Internal state. - self._initialized = False - - @property - def flattened_input_dims(self): - assert self._initialized, 'InputsStructureRecorder is not initialized.' - return self._flattened_input_dims - - def has_labels(self): - return 'labels' in self._feature_structure - - def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, - label_dims_names, label_names, has_labels): - """Flatten input dims with the same order as flattened input tensors.""" - flattened_input_dims = [] - if feature_dims_names: - # We need a fixed ordering for matching the tensors in features. 
- flattened_input_dims.extend( - [feature_dims[name] for name in feature_dims_names]) - else: - flattened_input_dims.append(feature_dims) - - if label_dims_names: - # We need a fixed ordering for matching the tensors in labels. - flattened_input_dims.extend( - [label_dims[name] for name in label_dims_names]) - else: - if label_names: - num_tensors_in_label = len(label_names) - else: - num_tensors_in_label = int(has_labels) - # Setting `None` in input_partition_dims[1] will apply `None` to - # all the tensors in labels, regardless of internal structure. - flattened_input_dims.extend([label_dims] * num_tensors_in_label) - - return flattened_input_dims - - def validate_and_record_structure(self, features, labels): - """Validates and records the structure of `features` and `labels`.""" - # Extract structure. - has_labels = labels is not None - feature_names = _extract_key_names(features) - label_names = _extract_key_names(labels) - - if not self._initialized: - # Record structure. - self._initialized = True - if self._feature_dims is not None: - feature_dims_names = _extract_key_names(self._feature_dims) - if feature_dims_names != feature_names: - raise ValueError( - 'TPUConfig.input_partition_dims[0] mismatched feature' - ' keys. Expected {}, got {}'.format(feature_names, - feature_dims_names)) - - label_dims_names = _extract_key_names(self._label_dims) - if self._label_dims is not None and label_dims_names != label_names: - raise ValueError( - 'TPUConfig.input_partition_dims[1] mismatched label' - ' keys. 
Expected {}, got {}'.format(label_names, - label_dims_names)) - - self._flattened_input_dims = self._flatten_input_dims( - self._feature_dims, feature_dims_names, self._label_dims, - label_dims_names, label_names, has_labels) - - def flatten_features_and_labels(self, features, labels, signals=None): - """Flattens the `features` and `labels` to a single tensor list.""" - self._feature_structure['features'] = features - if labels is not None: - self._feature_structure['labels'] = labels - if signals is not None: - self._feature_structure['signals'] = signals - return data_nest.flatten(self._feature_structure) - - def unflatten_features_and_labels(self, flattened_inputs): - """Restores the flattened inputs to original features and labels form. - - Args: - flattened_inputs: Flattened inputs for each shard. - - Returns: - A tuple of (`features`, `labels`), where `labels` could be None. - Each one, if present, should have identical structure (single tensor vs - dict) as the one returned by input_fn. - - Raises: - ValueError: If the number of expected tensors from `flattened_inputs` - mismatches the recorded structure. - """ - - unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure, - flattened_inputs) - return _Inputs( - unflattened_inputs['features'], - unflattened_inputs.get('labels'), - signals=unflattened_inputs.get('signals')) - - def __init__(self, input_fn, batch_axis, ctx): - """Constructor. - - Args: - input_fn: input fn for train or eval. - batch_axis: A python tuple of int values describing how each tensor - produced by the Estimator `input_fn` should be split across the TPU - compute shards. - ctx: A `_InternalTPUContext` instance with mode. - - Raises: - ValueError: If both `sharded_features` and `num_cores` are `None`. 
- """ - self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder( - ctx.input_partition_dims) - - self._sharded_per_core = ctx.is_input_sharded_per_core() - self._input_fn = input_fn - self._infeed_queue = None - self._ctx = ctx - self._batch_axis = batch_axis - - def generate_infeed_enqueue_ops_and_dequeue_fn(self): - """Generates infeed enqueue ops and dequeue_fn.""" - # While tf.while_loop is called, the body function, which invokes - # `enqueue_fn` passed in, is called to construct the graph. So, input_fn - # structure is recorded. - enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = ( - self._invoke_input_fn_and_record_structure()) - - self._validate_input_pipeline() - - def dequeue_fn(): - """dequeue_fn is used by TPU to retrieve the tensors.""" - # In the model-parallel case, both the host-side and device-side - # computations must agree on the core on which infeed takes place. We - # choose to perform infeed on logical core 0 of each replica. - values = self._infeed_queue.generate_dequeue_op(tpu_device=0) - # The unflatten process uses the structure information recorded above. - return self._inputs_structure_recorder.unflatten_features_and_labels( - values) - - return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator) - - def _invoke_input_fn_and_record_structure(self): - """Deploys the input pipeline and record input structure.""" - enqueue_ops = [] - infeed_queues = [] - all_dataset_initializers = [] - num_hosts = self._ctx.num_hosts - tpu_host_placement_fn = self._ctx.tpu_host_placement_function - - run_infeed_loop_on_coordinator = True - - if self._sharded_per_core: - # Per-Core input pipeline deployment. - # Invoke input pipeline for each core and placed on the corresponding - # host. 
- for host_id in range(num_hosts): - host_device = tpu_host_placement_fn(host_id=host_id) - with ops.device(host_device): - with ops.name_scope('input_pipeline_task%d' % (host_id)): - enqueue_ops_fn, captured_infeed_queue = ( - generate_per_core_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, self._inputs_structure_recorder, - host_device, host_id)) - - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - run_infeed_loop_on_coordinator = False - enqueue_ops.append( - _wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - # Infeed_queue_getter must be called after enqueue_ops_fn is called. - infeed_queues.append(captured_infeed_queue.get()) - - elif self._ctx.is_input_broadcast_with_iterators(): - # Only calls input_fn in host 0. - host_device = tpu_host_placement_fn(host_id=0) - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn, - self._inputs_structure_recorder, - num_hosts)) - if dataset_initializer: - all_dataset_initializers.append(dataset_initializer) - run_infeed_loop_on_coordinator = False - wrap_fn = ( - _wrap_computation_in_while_loop - if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else - _wrap_computation_in_while_loop_with_stopping_signals) - enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - infeed_queues.append(captured_infeed_queue.get()) - else: - for host_id in range(num_hosts): - host_device = tpu_host_placement_fn(host_id=host_id) - with ops.device(host_device): - with ops.name_scope('input_pipeline_task%d' % (host_id)): - if self._ctx.is_input_per_host_with_iterators(): - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - generate_per_host_v2_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, - self._inputs_structure_recorder, host_device, host_id)) - else: - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - 
generate_per_host_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, - self._inputs_structure_recorder, self._batch_axis, - host_device, host_id)) - - # NOTE(xiejw): We dispatch here based on the return type of the - # users `input_fn`. - # - # 1. If input_fn returns a Dataset instance, we initialize the - # iterator outside of tf.while_loop, and call the iterator.get_next - # inside tf.while_loop. This should be always safe. - # - # 2. If input_fn returns (features, labels), it is too late to wrap - # them inside tf.while_loop, as resource initialization cannot be - # handled in TF control flow properly. In this case, we will use - # python loop to enqueue the data into TPU system. This may be - # slow compared to the previous case. - if dataset_initializer: - all_dataset_initializers.append(dataset_initializer) - run_infeed_loop_on_coordinator = False - wrap_fn = ( - _wrap_computation_in_while_loop - if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else - _wrap_computation_in_while_loop_with_stopping_signals) - enqueue_ops.append( - wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - infeed_queues.append(captured_infeed_queue.get()) - # infeed_queue is used to generate dequeue ops. The only thing it uses for - # dequeue is dtypes and types. So, any one can be used. Here, grab the - # first one. - self._infeed_queue = infeed_queues[0] - return enqueue_ops, [ - util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers) - ], run_infeed_loop_on_coordinator - - def _validate_input_pipeline(self): - """Validates the input pipeline. - - Perform some sanity checks to log user friendly information. We should - error out to give users better error message. But, if - _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - user code, so, log a warning. - - Raises: - RuntimeError: If the validation failed. 
- """ - if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): - err_msg = ('Input pipeline contains one or more QueueRunners. ' - 'It could be slow and not scalable. Please consider ' - 'converting your input pipeline to use `tf.data` instead (see ' - 'https://www.tensorflow.org/guide/datasets for ' - 'instructions.') - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - raise RuntimeError(err_msg) - else: - logging.warn(err_msg) - - -class _ModelFnWrapper(object): - """A `model_fn` wrapper. - - This makes calling model_fn on CPU and TPU easier and more consistent and - performs necessary check and mutation required by TPU training and evaluation. - - In addition, this wrapper manages converting the `model_fn` to a single TPU - train and eval step. - """ - - def __init__(self, model_fn, config, params, ctx): - self._model_fn = model_fn - self._config = config - self._params = params - self._ctx = ctx - - def call_without_tpu(self, features, labels, is_export_mode): - return self._call_model_fn(features, labels, is_export_mode=is_export_mode) - - def convert_to_single_tpu_train_step(self, dequeue_fn): - """Converts user provided model_fn` as a single train step on TPU. - - The user provided `model_fn` takes input tuple - (features, labels) and produces the EstimatorSpec with train_op and loss for - train `mode`. This usually represents a single train computation on CPU. - - For TPU training, a train (computation) step is first wrapped in a - tf.while_loop control flow to repeat for many times and then replicated to - all TPU shards. Besides the input should be taken from TPU infeed rather - than input pipeline (input_fn) directly. To fit TPU loop and replicate - pattern, the original train computation should be reformed, which is the - returned `train_step`. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of train_fn, host_calls, and captured scaffold_fn. 
The train_fn - representing the train step for TPU. - """ - - host_call = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_training_hooks = _CapturedObject() - - def train_step(loss): - """Training step function for use inside a while loop.""" - del loss # unused; required in function signature. - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - - estimator_spec = self._verify_estimator_spec( - self._call_model_fn(features, labels)) - loss, train_op = estimator_spec.loss, estimator_spec.train_op - - if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - captured_scaffold_fn.capture(estimator_spec.scaffold_fn) - else: - captured_scaffold_fn.capture(None) - - captured_training_hooks.capture(estimator_spec.training_hooks) - - tracing_ops = [] - if tensor_tracer.TensorTracer.is_enabled(): - tt = tensor_tracer.TensorTracer() - loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss, - self._ctx.num_replicas) - - # We must run train_op to update the variables prior to running the - # outfeed. - with ops.control_dependencies([train_op]+tracing_ops): - host_call_outfeed_ops = [] - if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access - and estimator_spec.host_call is not None): - host_call.record({'host_call': estimator_spec.host_call}) - host_call_outfeed_ops = host_call.create_enqueue_op() - with ops.control_dependencies(host_call_outfeed_ops): - return array_ops.identity(loss) - - return (train_step, host_call, captured_scaffold_fn, - captured_training_hooks) - - def convert_to_single_tpu_eval_step(self, dequeue_fn): - """Converts user provided model_fn` as a single eval step on TPU. - - Similar to training, the user provided `model_fn` takes input tuple - (features, labels) and produces the TPUEstimatorSpec with eval_metrics for - eval `mode`. This usually represents a single evaluation computation on CPU. 
- - For TPU evaluation, a eval (computation) step is first wrapped in a - tf.while_loop control flow to repeat for many times and then replicated to - all TPU shards. Besides the input and output are slightly different. Input, - features and labels, should be taken from TPU infeed rather than input - pipeline (input_fn) directly. Output is managed in two stages. First, the - model outputs as the result of evaluation computation, usually model logits, - should be transferred from TPU system to CPU. Then, all model outputs are - concatenated first on CPU and sent to the metric_fn for metrics computation. - To fit TPU evaluation pattern, the original eval computation should be - reformed, which is the returned `eval_step`. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn - representing the eval step for TPU. - """ - host_calls = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_eval_hooks = _CapturedObject() - - def eval_step(total_loss): - """Evaluation step function for use inside a while loop.""" - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - - tpu_estimator_spec = self._call_model_fn(features, labels) - if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - raise RuntimeError( - 'estimator_spec used by TPU evaluation must have type' - '`TPUEstimatorSpec`. 
Got {}'.format(type(tpu_estimator_spec))) - - loss = tpu_estimator_spec.loss - captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) - - to_record = {} - if tpu_estimator_spec.eval_metrics: - to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics - if tpu_estimator_spec.host_call is not None: - # We assume that evaluate won't update global step, so we don't wrap - # this host_call. - to_record['host_call'] = tpu_estimator_spec.host_call - host_calls.record(to_record) - - with ops.control_dependencies(host_calls.create_enqueue_op()): - return math_ops.add(total_loss, loss) - - return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks - - def convert_to_single_tpu_predict_step(self, dequeue_fn): - """Converts user provided model_fn` as a single predict step on TPU. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of predict_fn, host_calls, and captured scaffold_fn. The - predict_fn representing the predict step for TPU. - """ - host_calls = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_predict_hooks = _CapturedObject() - - def predict_step(unused_scalar_stopping_signal): - """Evaluation step function for use inside a while loop.""" - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - stopping_signals = inputs.signals() - - assert stopping_signals is not None, ( - 'Internal Error: `signals` is missing.') - - tpu_estimator_spec = self._call_model_fn( - features, labels, is_export_mode=False) - if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - raise RuntimeError( - 'estimator_spec used by TPU prediction must have type' - '`TPUEstimatorSpec`. 
Got {}'.format(type(tpu_estimator_spec))) - - self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) - - captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) - to_record = {} - identity_fn = lambda **kwargs: kwargs - to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] - to_record['signals'] = [identity_fn, stopping_signals] - if tpu_estimator_spec.host_call is not None: - to_record['host_call'] = tpu_estimator_spec.host_call - host_calls.record(to_record) - - with ops.control_dependencies(host_calls.create_enqueue_op()): - return _StopSignals.as_scalar_stopping_signal(stopping_signals) - - return (predict_step, host_calls, captured_scaffold_fn, - captured_predict_hooks) - - def _verify_tpu_spec_predictions(self, predictions): - """Validates TPUEstimatorSpec.predictions dict.""" - # TODO(xiejw): Adds validation for prediction dictionrary. - # TODO(xiejw): Adds support for single tensor as predictions. - if not isinstance(predictions, dict): - raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') - - for (key, tensor) in predictions.items(): - if tensor.shape.dims[0].value is None: - raise ValueError( - 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' - 'dynamic shape (should be static). Tensor: {}'.format(key, tensor)) - return predictions - - def _validate_model_features_and_labels(self, features, labels, - is_export_mode): - """Validates that the features and labels for the model function are valid. - - A valid features/labels object is the one with: - - Type: A tensor or any nested structure of tensors supported by TF nest, - namely nested dictionary, tuple, namedtuple, or sequence of tensors. - - Static shape if is_export_mode is False. - - Args: - features: the features that would be input to the model function. - labels: the labels that would be input to the model function. 
- is_export_mode: boolean value specifying if in export mode. - - Raises: - TypeError: If features/labels are not of the correct type. - ValueError: If features/labels have dynamic shape. - """ - - def validate(obj, obj_name): - """Helper validate function.""" - if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): - return - if isinstance(obj, ops.Tensor): - if not obj.get_shape().is_fully_defined(): - raise ValueError( - 'The {} to the model returned by input_fn must have static shape.' - ' Tensor: {}'.format(obj_name, obj)) - else: - for tensor in data_nest.flatten(obj): - if not tensor.get_shape().is_fully_defined(): - raise ValueError( - ('The {} to the model returned by input_fn must have static ' - 'shape. Tensor: {}').format(obj_name, tensor)) - - validate(features, 'features') - if labels is not None: - validate(labels, 'labels') - - def _call_model_fn(self, features, labels, is_export_mode=False): - """Calls the model_fn with required parameters.""" - self._validate_model_features_and_labels(features, labels, is_export_mode) - model_fn_args = function_utils.fn_args(self._model_fn) - kwargs = {} - - # Makes deep copy with `config` and params` in case user mutates them. 
- config = copy.deepcopy(self._config) - params = copy.deepcopy(self._params) - - if 'labels' in model_fn_args: - kwargs['labels'] = labels - elif labels is not None: - raise ValueError( - 'model_fn does not take labels, but input_fn returns labels.') - if 'mode' in model_fn_args: - kwargs['mode'] = self._ctx.mode - if 'config' in model_fn_args: - kwargs['config'] = config - if 'params' in model_fn_args: - kwargs['params'] = params - - if 'params' not in model_fn_args: - raise ValueError('model_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params[\'batch_size\']'.format(self._model_fn)) - - if is_export_mode: - batch_size_for_model_fn = None - else: - batch_size_for_model_fn = self._ctx.batch_size_for_model_fn - - if batch_size_for_model_fn is not None: - _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) - - running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) - _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) - - if not running_on_cpu: - user_context = tpu_context.TPUContext( - internal_ctx=self._ctx, call_from_input_fn=False) - _add_item_to_params(params, _CTX_KEY, user_context) - - estimator_spec = self._model_fn(features=features, **kwargs) - if (running_on_cpu and - isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access - # The estimator_spec will be passed to `Estimator` directly, which expects - # type `EstimatorSpec`. - return estimator_spec.as_estimator_spec() - else: - return estimator_spec - - def _verify_estimator_spec(self, estimator_spec): - """Validates the estimator_spec.""" - if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - return estimator_spec - - err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' 
- if estimator_spec.training_chief_hooks: - raise ValueError( - err_msg.format('training_chief_hooks') + 'If you want' + - ' to pass training hooks, please pass via training_hooks.') - - if estimator_spec.scaffold: - logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' - 'Please use TPUEstimatorSpec.') - return estimator_spec - - -class _OutfeedHostCall(object): - """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.""" - - def __init__(self, ctx): - self._ctx = ctx - self._names = [] - # All of these are dictionaries of lists keyed on the name. - self._host_fns = {} - self._tensor_keys = collections.defaultdict(list) - self._tensors = collections.defaultdict(list) - self._tensor_dtypes = collections.defaultdict(list) - self._tensor_shapes = collections.defaultdict(list) - - @staticmethod - def validate(host_calls): - """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.""" - - for name, host_call in host_calls.items(): - if not isinstance(host_call, (tuple, list)): - raise ValueError('{} should be tuple or list'.format(name)) - if len(host_call) != 2: - raise ValueError('{} should have two elements.'.format(name)) - if not callable(host_call[0]): - raise TypeError('{}[0] should be callable.'.format(name)) - if not isinstance(host_call[1], (tuple, list, dict)): - raise ValueError('{}[1] should be tuple or list, or dict.'.format(name)) - - if isinstance(host_call[1], (tuple, list)): - fullargspec = tf_inspect.getfullargspec(host_call[0]) - fn_args = function_utils.fn_args(host_call[0]) - # wrapped_hostcall_with_global_step uses varargs, so we allow that. 
- if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): - raise RuntimeError( - 'In TPUEstimatorSpec.{}, length of tensors {} does not match ' - 'method args of the function, which takes {}.'.format( - name, len(host_call[1]), len(fn_args))) - - @staticmethod - def create_cpu_hostcall(host_calls): - """Runs on the host_call on CPU instead of TPU when use_tpu=False.""" - - _OutfeedHostCall.validate(host_calls) - ret = {} - for name, host_call in host_calls.items(): - host_fn, tensors = host_call - if isinstance(tensors, (tuple, list)): - ret[name] = host_fn(*tensors) - else: - # Must be dict. - try: - ret[name] = host_fn(**tensors) - except TypeError as e: - logging.warning( - 'Exception while calling %s: %s. It is likely the tensors ' - '(%s[1]) do not match the ' - 'function\'s arguments', name, e, name) - raise e - return ret - - def record(self, host_calls): - """Records the host_call structure.""" - - for name, host_call in host_calls.items(): - host_fn, tensor_list_or_dict = host_call - self._names.append(name) - self._host_fns[name] = host_fn - - if isinstance(tensor_list_or_dict, dict): - for (key, tensor) in six.iteritems(tensor_list_or_dict): - self._tensor_keys[name].append(key) - self._tensors[name].append(tensor) - self._tensor_dtypes[name].append(tensor.dtype) - self._tensor_shapes[name].append(tensor.shape) - else: - # List or tuple. - self._tensor_keys[name] = None - for tensor in tensor_list_or_dict: - self._tensors[name].append(tensor) - self._tensor_dtypes[name].append(tensor.dtype) - self._tensor_shapes[name].append(tensor.shape) - - def create_enqueue_op(self): - """Create the op to enqueue the recorded host_calls. - - Returns: - A list of enqueue ops, which is empty if there are no host calls. - """ - if not self._names: - return [] - - tensors = [] - # TODO(jhseu): Consider deduping tensors. 
- for name in self._names: - tensors.extend(self._tensors[name]) - - with ops.device(tpu.core(0)): - return [tpu_ops.outfeed_enqueue_tuple(tensors)] - - def create_tpu_hostcall(self): - """Sends the tensors through outfeed and runs the host_fn on CPU. - - The tensors are concatenated along dimension 0 to form a global tensor - across all shards. The concatenated function is passed to the host_fn and - executed on the first host. - - Returns: - A dictionary mapping name to the return type of the host_call by that - name. - - Raises: - RuntimeError: If outfeed tensor is scalar. - """ - if not self._names: - return {} - - ret = {} - # For each i, dequeue_ops[i] is a list containing the tensors from all - # shards. This list is concatenated later. - dequeue_ops = [] - tensor_dtypes = [] - tensor_shapes = [] - for name in self._names: - for _ in self._tensors[name]: - dequeue_ops.append([]) - for dtype in self._tensor_dtypes[name]: - tensor_dtypes.append(dtype) - for shape in self._tensor_shapes[name]: - tensor_shapes.append(shape) - - # Outfeed ops execute on each replica's first logical core. Note: we must - # constraint it such that we have at most one outfeed dequeue and enqueue - # per replica. - for i in xrange(self._ctx.num_replicas): - host_device, ordinal_id = self._ctx.device_for_replica(i) - with ops.device(host_device): - outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=tensor_dtypes, - shapes=tensor_shapes, - device_ordinal=ordinal_id) - for j, item in enumerate(outfeed_tensors): - dequeue_ops[j].append(item) - - # Deconstruct dequeue ops. - dequeue_ops_by_name = {} - pos = 0 - for name in self._names: - dequeue_ops_by_name[name] = dequeue_ops[pos:pos + - len(self._tensors[name])] - pos += len(self._tensors[name]) - - # It is assumed evaluation always happens on single host TPU system. So, - # place all ops on tpu host if possible. - # - # TODO(jhseu): Evaluate whether this is right for summaries. 
- with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)): - for name in self._names: - dequeue_ops = dequeue_ops_by_name[name] - for i, item in enumerate(dequeue_ops): - if dequeue_ops[i][0].shape.ndims == 0: - raise RuntimeError( - 'All tensors outfed from TPU should preserve batch size ' - 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) - # TODO(xiejw): Allow users to specify the axis for batch size - # dimension. - dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) - - if self._tensor_keys[name] is not None: - # The user-provided eval_metrics[1] is a dict. - dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) - try: - ret[name] = self._host_fns[name](**dequeue_ops) - except TypeError as e: - logging.warning( - 'Exception while calling %s: %s. It is likely the tensors ' - '(%s[1]) do not match the ' - 'function\'s arguments', name, e, name) - raise e - else: - ret[name] = self._host_fns[name](*dequeue_ops) - - return ret - - -class _OutfeedHostCallHook(session_run_hook.SessionRunHook): - """Hook to run host calls when use_tpu=False.""" - - def __init__(self, tensors): - self._tensors = tensors - - def begin(self): - # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than - # create a separate hook to guarantee execution order, because summaries - # need to be initialized before the outfeed thread starts. - # TODO(jhseu): Make a wrapper hook instead? - self._init_ops = contrib_summary.summary_writer_initializer_op() - # Get all the writer resources from the initializer, so we know what to - # flush. 
- self._finalize_ops = [] - for op in self._init_ops: - self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - - def after_create_session(self, session, coord): - session.run(self._init_ops) - - def before_run(self, run_context): - return basic_session_run_hooks.SessionRunArgs(self._tensors) - - def end(self, session): - session.run(self._finalize_ops) - - -class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): - """Calculate and report global_step/sec and examples/sec during runtime.""" - - def __init__(self, - batch_size, - every_n_steps=100, - every_n_secs=None, - output_dir=None, - summary_writer=None): - self._batch_size = batch_size - super(ExamplesPerSecondHook, self).__init__( - every_n_steps=every_n_steps, - every_n_secs=every_n_secs, - output_dir=output_dir, - summary_writer=summary_writer) - - def _log_and_record(self, elapsed_steps, elapsed_time, global_step): - global_step_per_sec = elapsed_steps / elapsed_time - examples_per_sec = self._batch_size * global_step_per_sec - if self._summary_writer is not None: - global_step_summary = Summary(value=[ - Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec) - ]) - example_summary = Summary(value=[ - Summary.Value(tag='examples/sec', simple_value=examples_per_sec) - ]) - self._summary_writer.add_summary(global_step_summary, global_step) - self._summary_writer.add_summary(example_summary, global_step) - logging.info('global_step/sec: %g', global_step_per_sec) - logging.info('examples/sec: %g', examples_per_sec) - - -class InstallSignalHandlerHook(session_run_hook.SessionRunHook): - """Change SIGINT (CTRL^C) handler to force quit the process. - - The default behavior often results in hanging processes. - The original handler is restored after training/evaluation. 
- """ - - def __init__(self): - self._signal_fn = signal.getsignal(signal.SIGINT) - - def before_run(self, run_context): - signal.signal(signal.SIGINT, signal.SIG_DFL) - - def end(self, session): - signal.signal(signal.SIGINT, self._signal_fn) - - -class TPUEstimator(estimator_lib.Estimator): - """Estimator with TPU support. - - TPUEstimator also supports training on CPU and GPU. You don't need to define - a separate `tf.estimator.Estimator`. - - TPUEstimator handles many of the details of running on TPU devices, such as - replicating inputs and models for each core, and returning to host - periodically to run hooks. - - TPUEstimator transforms a global batch size in params to a per-shard batch - size when calling the `input_fn` and `model_fn`. Users should specify - global batch size in constructor, and then get the batch size for each shard - in `input_fn` and `model_fn` by `params['batch_size']`. - - - For training, `model_fn` gets per-core batch size; `input_fn` may get - per-core or per-host batch size depending on `per_host_input_for_training` - in `TPUConfig` (See docstring for TPUConfig for details). - - - For evaluation and prediction, `model_fn` gets per-core batch size and - `input_fn` get per-host batch size. - - Evaluation - ========== - - `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` - for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return - `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case - the following discussion on TPU evaluation does not apply. - - `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where - `tensors` could be a list of any nested structure of `Tensor`s (See - `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns - a dict from metric string name to the result of calling a metric function, - namely a `(metric_tensor, update_op)` tuple. - - One can set `use_tpu` to `False` for testing. 
All training, evaluation, and - predict will be executed on CPU. `input_fn` and `model_fn` will receive - `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`. - - Current limitations: - -------------------- - - 1. TPU evaluation only works on a single host (one TPU worker) except - BROADCAST mode. - - 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception - (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all - batches should have the same size. - - Example (MNIST): - ---------------- - - ``` - # The metric Fn which runs on CPU. - def metric_fn(labels, logits): - predictions = tf.argmax(logits, 1) - return { - 'accuracy': tf.metrics.precision( - labels=labels, predictions=predictions), - } - - # Your model Fn which runs on TPU (eval_metrics is list in this example) - def model_fn(features, labels, mode, config, params): - ... - logits = ... - - if mode = tf.estimator.ModeKeys.EVAL: - return tpu_estimator.TPUEstimatorSpec( - mode=mode, - loss=loss, - eval_metrics=(metric_fn, [labels, logits])) - - # or specify the eval_metrics tensors as dict. - def model_fn(features, labels, mode, config, params): - ... - final_layer_output = ... - - if mode = tf.estimator.ModeKeys.EVAL: - return tpu_estimator.TPUEstimatorSpec( - mode=mode, - loss=loss, - eval_metrics=(metric_fn, { - 'labels': labels, - 'logits': final_layer_output, - })) - ``` - - Prediction - ========== - - Prediction on TPU is an experimental feature to support large batch inference. - It is not designed for latency-critical system. In addition, due to some - usability issues, for prediction with small dataset, CPU `.predict`, i.e., - creating a new `TPUEstimator` instance with `use_tpu=False`, might be more - convenient. - - Note: In contrast to TPU training/evaluation, the `input_fn` for prediction - *should* raise an end-of-input exception (`OutOfRangeError` or - `StopIteration`), which serves as the stopping signal to `TPUEstimator`. 
To be - precise, the ops created by `input_fn` produce one batch of the data. - The `predict()` API processes one batch at a time. When reaching the end of - the data source, an end-of-input exception should be raised by one of these - operations. The user usually does not need to do this manually. As long as the - dataset is not repeated forever, the `tf.data` API will raise an end-of-input - exception automatically after the last batch has been produced. - - Note: Estimator.predict returns a Python generator. Please consume all the - data from the generator so that TPUEstimator can shutdown the TPU system - properly for user. - - Current limitations: - -------------------- - 1. TPU prediction only works on a single host (one TPU worker). - - 2. `input_fn` must return a `Dataset` instance rather than `features`. In - fact, .train() and .evaluate() also support Dataset as return value. - - Example (MNIST): - ---------------- - ``` - height = 32 - width = 32 - total_examples = 100 - - def predict_input_fn(params): - batch_size = params['batch_size'] - - images = tf.random_uniform( - [total_examples, height, width, 3], minval=-1, maxval=1) - - dataset = tf.data.Dataset.from_tensor_slices(images) - dataset = dataset.map(lambda images: {'image': images}) - - dataset = dataset.batch(batch_size) - return dataset - - def model_fn(features, labels, params, mode): - # Generate predictions, called 'output', from features['image'] - - if mode == tf.estimator.ModeKeys.PREDICT: - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, - predictions={ - 'predictions': output, - 'is_padding': features['is_padding'] - }) - - tpu_est = TPUEstimator( - model_fn=model_fn, - ..., - predict_batch_size=16) - - # Fully consume the generator so that TPUEstimator can shutdown the TPU - # system. - for item in tpu_est.predict(input_fn=input_fn): - # Filter out item if the `is_padding` is 1. 
- # Process the 'predictions' - ``` - - Exporting - ========= - - `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, - and another with `tag_constants.SERVING` and `tag_constants.TPU`. - At serving time, these tags are used to select metagraph to load. - - Before running the graph on TPU, TPU system needs to be initialized. If - TensorFlow Serving model-server is used, this is done automatically. If - not, please call `session.run(tpu.initialize_system())`. - - `tpu.outside_compilation` can be used to wrap TPU incompatible ops in - `model_fn`. - - Example: - ---------------- - - ``` - def model_fn(features, labels, mode, config, params): - ... - logits = ... - export_outputs = { - 'logits': export_output_lib.PredictOutput( - {'logits': logits}) - } - - def host_call(logits): - class_ids = math_ops.argmax(logits) - classes = string_ops.as_string(class_ids) - export_outputs['classes'] = - export_output_lib.ClassificationOutput(classes=classes) - - tpu.outside_compilation(host_call, logits) - - ... - ``` - - """ - - def __init__(self, - model_fn=None, - model_dir=None, - config=None, - params=None, - use_tpu=True, - train_batch_size=None, - eval_batch_size=None, - predict_batch_size=None, - batch_axis=None, - eval_on_tpu=True, - export_to_tpu=True, - warm_start_from=None): - """Constructs an `TPUEstimator` instance. - - Args: - model_fn: Model function as required by `Estimator` which returns - EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', - and `prediction_hooks` must not capure any TPU Tensor inside the - model_fn. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. If `None`, the model_dir in - `config` will be used if set. If both are set, they must be same. If - both are `None`, a temporary directory will be used. - config: An `tpu_config.RunConfig` configuration object. 
Cannot be `None`. - params: An optional `dict` of hyper parameters that will be passed into - `input_fn` and `model_fn`. Keys are names of parameters, values are - basic python types. There are reserved keys for `TPUEstimator`, - including 'batch_size'. - use_tpu: A bool indicating whether TPU support is enabled. Currently, - - TPU training and evaluation respect this bit, but eval_on_tpu can - override execution of eval. See below. - Predict still happens on CPU. - train_batch_size: An int representing the global training batch size. - TPUEstimator transforms this global batch size to a per-shard batch - size, as params['batch_size'], when calling `input_fn` and `model_fn`. - Cannot be `None` if `use_tpu` is `True`. Must be divisible by total - number of replicas. - eval_batch_size: An int representing evaluation batch size. Must be - divisible by total number of replicas. - predict_batch_size: An int representing the prediction batch size. Must be - divisible by total number of replicas. - batch_axis: A python tuple of int values describing how each tensor - produced by the Estimator `input_fn` should be split across the TPU - compute shards. For example, if your input_fn produced (images, labels) - where the images tensor is in `HWCN` format, your shard dimensions would - be [3, 0], where 3 corresponds to the `N` dimension of your images - Tensor, and 0 corresponds to the dimension along which to split the - labels to match up with the corresponding images. If None is supplied, - and per_host_input_for_training is True, batches will be sharded based - on the major dimension. If tpu_config.per_host_input_for_training is - False or `PER_HOST_V2`, batch_axis is ignored. - eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the - model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. - export_to_tpu: If True, `export_savedmodel()` exports a metagraph for - serving on TPU besides the one on CPU. 
- warm_start_from: Optional string filepath to a checkpoint or SavedModel to - warm-start from, or a `tf.estimator.WarmStartSettings` object to fully - configure warm-starting. If the string filepath is provided instead of - a `WarmStartSettings`, then all variables are warm-started, and it is - assumed that vocabularies and Tensor names are unchanged. - - Raises: - ValueError: `params` has reserved keys already. - """ - if config is None or not isinstance(config, tpu_config.RunConfig): - raise ValueError( - '`config` must be provided with type `tpu_config.RunConfig`') - - if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS): - raise ValueError('{} are reserved keys but existed in params {}.'.format( - _RESERVED_PARAMS_KEYS, params)) - - if use_tpu: - # Perform some very basic validations. More validations will be found in - # _InternalTPUContext. - if train_batch_size is None: - raise ValueError('`train_batch_size` cannot be `None`') - util_lib.check_positive_integer(train_batch_size, 'train_batch_size') - - if (config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_SHARD_V1 and - config.tpu_config.num_cores_per_replica): - raise ValueError( - 'Model parallelism only supports per host input for training. ' - 'Please adjust TPURunconfig.per_host_input_for_training.') - - if eval_batch_size is not None: - util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size') - - if predict_batch_size is not None: - util_lib.check_positive_integer(predict_batch_size, - 'predict_batch_size') - - # Verifies the model_fn signature according to Estimator framework. - estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access - # We cannot store config and params in this constructor as parent - # constructor might change them, such as assigning a temp dir for - # config.model_dir. 
- model_function = self._augment_model_fn(model_fn, batch_axis) - - # Overwrite log_step_count_steps to disable TensorLoggingHook and - # StepCounterHook from being created in Estimator. TPUEstimator already - # added equivalent hooks in _augment_model_fn above. - self._log_every_n_steps = config.log_step_count_steps - config = config.replace(log_step_count_steps=None) - - # Passing non-None params as wrapped model_fn has it. - params = params or {} - super(TPUEstimator, self).__init__( - model_fn=model_function, - model_dir=model_dir, - config=config, - params=params, - warm_start_from=warm_start_from) - self._iterations_per_training_loop = ( - self._config.tpu_config.iterations_per_loop) - - # All properties passed to _InternalTPUContext are immutable. - # pylint: disable=protected-access - self._ctx = tpu_context._get_tpu_context( - self._config, train_batch_size, eval_batch_size, predict_batch_size, - use_tpu, eval_on_tpu) - - self._export_to_tpu = export_to_tpu - - self._is_input_fn_invoked = None - self._rendezvous = {} - - def _add_meta_graph_for_mode(self, - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables=True, - mode=model_fn_lib.ModeKeys.PREDICT, - export_tags=None, - check_variables=True): - if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: - raise NotImplementedError( - 'TPUEstimator only handles mode PREDICT for exporting ' - 'when `export_to_tpu` is `True`; ' - 'got {}.'.format(mode)) - - (super(TPUEstimator, self)._add_meta_graph_for_mode( - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables, - mode=mode, - export_tags=export_tags, - check_variables=check_variables)) - - if self._export_to_tpu: - input_receiver_fn_map = { - _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode] - } - export_tags = [tag_constants.SERVING, tag_constants.TPU] - mode = _REWRITE_FOR_INFERENCE_MODE - # See b/110052256 for why `check_variables` is `False`. 
- (super(TPUEstimator, self)._add_meta_graph_for_mode( - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables=False, - mode=mode, - export_tags=export_tags, - check_variables=False)) - - def _call_model_fn(self, features, labels, mode, config): - if mode == _REWRITE_FOR_INFERENCE_MODE: - return self._call_model_fn_for_inference(features, labels, mode, config) - else: - return super(TPUEstimator, self)._call_model_fn(features, labels, mode, - config) - - def _call_model_fn_for_inference(self, features, labels, mode, config): - """Wraps `_call_model_fn` for `export_savedmodel`.""" - if mode != _REWRITE_FOR_INFERENCE_MODE: - raise ValueError('mode must be {}; ' - 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) - - capture = _CapturedObject() - - def computation(): - """Compute tpu tensors used in export_outputs. - - Passed to rewrite so that model_fn will be called under - the rewriting contexts. Only tpu tensors are returned, but export_outputs - and scaffold are captured. - - Returns: - A list of Tensors used in export_outputs and not marked for - outside_compilation. - """ - # We should only call model fn once and it should be inside `computation` - # so that building the graph will happen under `rewrite`. - mode = model_fn_lib.ModeKeys.PREDICT - estimator_spec = self._call_model_fn(features, labels, mode, config) - - # We pick the TPU tensors out from `export_output` and later return them - # from `computation` for rewriting. - tensors_dict = collections.OrderedDict( - (k, _export_output_to_tensors(v)) - for k, v in six.iteritems(estimator_spec.export_outputs)) - tensors = nest.flatten(tensors_dict) - tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] - - # We cannot return anything other than `tpu_tensors` here so we capture - # the rest for later use. 
- capture.capture((estimator_spec, tensors_dict, tensors)) - return tpu_tensors - - tpu_tensors_on_cpu = tpu.rewrite(computation) - estimator_spec, tensors_dict, tensors = capture.get() - - # Reconstruct `tensors`, but with `tpu_tensors` replaced with - # `tpu_tensors_on_cpu`. - new_tensors = [] - for t in tensors: - if _is_tpu_tensor(t): - new_tensors.append(tpu_tensors_on_cpu.pop(0)) - elif t is None: - new_tensors.append(None) - else: - # Only fetching `tpu_tensors_on_cpu` does not trigger - # TPU computation and blocks, so we add the control dependency here. - control_inputs = ( - tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else - (tpu_tensors_on_cpu,)) - with ops.control_dependencies(control_inputs): - new_tensors.append(array_ops.identity(t)) - - # Reconstruct `tensors_dict`. - new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) - # Reconstruct `export_outputs`. - export_outputs = estimator_spec.export_outputs - new_export_outputs = collections.OrderedDict( - (k, _clone_export_output_with_tensors(export_outputs[k], v)) - for k, v in six.iteritems(new_tensors_dict)) - - return estimator_spec._replace(export_outputs=new_export_outputs) - - def _create_global_step(self, graph): - """Creates a global step suitable for TPUs. - - Args: - graph: The graph in which to create the global step. - - Returns: - A global step `Tensor`. - - Raises: - ValueError: if the global step tensor is already defined. - """ - return _create_global_step(graph) - - def _convert_train_steps_to_hooks(self, steps, max_steps): - with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: - if ctx.is_running_on_cpu(): - return super(TPUEstimator, self)._convert_train_steps_to_hooks( - steps, max_steps) - - # On TPU. - if steps is None and max_steps is None: - raise ValueError( - 'For TPU training, one of `steps` or `max_steps` must be set. ' - 'Cannot be both `None`.') - - # Estimator.train has explicit positiveness check. 
- if steps is not None: - util_lib.check_positive_integer(steps, 'Train steps') - if max_steps is not None: - util_lib.check_positive_integer(max_steps, 'Train max_steps') - - return [ - _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps) - ] - - def _convert_eval_steps_to_hooks(self, steps): - with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: - if ctx.is_running_on_cpu(): - return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) - - if steps is None: - raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') - - util_lib.check_positive_integer(steps, 'Eval steps') - - return [ - evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access - num_evals=steps), - _SetEvalIterationsHook(steps) - ] - - def _call_input_fn(self, input_fn, mode): - """Calls the input function. - - Args: - input_fn: The input function. - mode: ModeKeys - - Returns: - In TPU mode, returns an input_fn to be called later in model_fn. - Otherwise, calls the input_fn and returns either fatures or - (features, labels). - - Raises: - ValueError: if input_fn takes invalid arguments or does not have `params`. - """ - input_fn_args = function_utils.fn_args(input_fn) - config = self.config # a deep copy. - kwargs = {} - if 'params' in input_fn_args: - kwargs['params'] = self.params # a deep copy. - else: - raise ValueError('input_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params["batch_size"]'.format(input_fn)) - if 'config' in input_fn_args: - kwargs['config'] = config - - if 'mode' in input_fn_args: - kwargs['mode'] = mode - - # Records the fact input_fn has been invoked. - self._is_input_fn_invoked = True - - with self._ctx.with_mode(mode) as ctx: - # Setting the batch size in params first. This helps user to have same - # input_fn for use_tpu=True/False. 
- batch_size_for_input_fn = ctx.batch_size_for_input_fn - if batch_size_for_input_fn is not None: - _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY, - batch_size_for_input_fn) - - # For export_savedmodel, input_fn is never passed to Estimator. So, - # `is_export_mode` must be False. - if ctx.is_running_on_cpu(is_export_mode=False): - with ops.device('/device:CPU:0'): - return input_fn(**kwargs) - - # For TPU computation, input_fn should be invoked in a tf.while_loop for - # performance. While constructing the tf.while_loop, the structure of - # inputs returned by the `input_fn` needs to be recorded. The structure - # includes whether features or labels is dict or single Tensor, dict keys, - # tensor shapes, and dtypes. The recorded structure is used to create the - # infeed dequeue ops, which must be wrapped and passed as a Fn, called - # inside the TPU computation, as the TPU computation is wrapped inside a - # tf.while_loop also. So, we either pass input_fn to model_fn or pass - # dequeue_fn to model_fn. Here, `input_fn` is passed directly as - # `features` in `model_fn` signature. - def _input_fn(ctx): - _add_item_to_params(kwargs['params'], _CTX_KEY, ctx) - return input_fn(**kwargs) - - return _input_fn - - def _validate_features_in_predict_input(self, result): - """Skip the validation. - - For TPUEstimator, we do not need to check the result type. `_InputPipeline` - has stronger check. Parent class's check generates confusing warning msg. - - Args: - result: `features` returned by input_fn. 
- """ - pass - - def train(self, - input_fn, - hooks=None, - steps=None, - max_steps=None, - saving_listeners=None): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous - try: - return super(TPUEstimator, self).train( - input_fn=input_fn, - hooks=hooks, - steps=steps, - max_steps=max_steps, - saving_listeners=saving_listeners) - except Exception: # pylint: disable=broad-except - rendezvous.record_error('training_loop', sys.exc_info()) - finally: - rendezvous.record_done('training_loop') - rendezvous.raise_errors() - - def evaluate(self, - input_fn, - steps=None, - hooks=None, - checkpoint_path=None, - name=None): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous - try: - return super(TPUEstimator, self).evaluate( - input_fn, - steps=steps, - hooks=hooks, - checkpoint_path=checkpoint_path, - name=name) - except Exception: # pylint: disable=broad-except - rendezvous.record_error('evaluation_loop', sys.exc_info()) - finally: - rendezvous.record_done('evaluation_loop') - rendezvous.raise_errors() - - def predict(self, - input_fn, - predict_keys=None, - hooks=None, - checkpoint_path=None, - yield_single_examples=True): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous - try: - for result in super(TPUEstimator, self).predict( - input_fn=input_fn, - predict_keys=predict_keys, - hooks=hooks, - checkpoint_path=checkpoint_path, - yield_single_examples=yield_single_examples): - yield result - except Exception: # pylint: disable=broad-except - rendezvous.record_error('prediction_loop', sys.exc_info()) - finally: - rendezvous.record_done('prediction_loop') - rendezvous.raise_errors() - - rendezvous.record_done('prediction_loop') - rendezvous.raise_errors() - - def _augment_model_fn(self, model_fn, batch_axis): - """Returns a new model_fn, which wraps the TPU 
support.""" - - def _model_fn(features, labels, mode, config, params): - """A Estimator `model_fn` for TPUEstimator.""" - with self._ctx.with_mode(mode) as ctx: - model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) - - # `input_fn` is called in `train()`, `evaluate()`, and `predict()`, - # but not in `export_savedmodel()`. - if self._is_input_fn_invoked: - is_export_mode = False - else: - is_export_mode = True - - # Clear the bit. - self._is_input_fn_invoked = None - - # examples_hook is added to training_hooks for both CPU and TPU - # execution. - if self._log_every_n_steps is not None: - examples_hook = ExamplesPerSecondHook( - ctx.global_batch_size, - output_dir=self.model_dir, - every_n_steps=self._log_every_n_steps) - - if ctx.is_running_on_cpu(is_export_mode=is_export_mode): - logging.info('Running %s on CPU', mode) - estimator_spec = model_fn_wrapper.call_without_tpu( - features, labels, is_export_mode=is_export_mode) - if self._log_every_n_steps is not None: - estimator_spec = estimator_spec._replace( - training_hooks=estimator_spec.training_hooks + (examples_hook,)) - return estimator_spec - - assert labels is None, '`labels` passed to `model_fn` must be `None`.' - # TPUEstimator._call_input_fn passes `input_fn` as features to here. - assert callable(features), '`input_fn` is not callable.' 
- input_fn = features - - input_holders = _InputPipeline(input_fn, batch_axis, ctx) - enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( - input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) - - graph = ops.get_default_graph() - for enqueue_op in enqueue_ops: - if isinstance(enqueue_op, list): - graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op) - else: - graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) - - if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold, training_hooks = ( - _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) - host_ops = host_call.create_tpu_hostcall() - if host_ops is None: - host_ops = [] - - shutdown_hooks = [] - shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE', - 'shutdown_worker') - if shutdown_mode: - if shutdown_mode == 'shutdown_worker': - finalizer_hooks = [ - session_support.ShutdownLameWorkers(timeout_ms=60 * 1000), - ] - elif shutdown_mode == 'shutdown_computation': - finalizer_hooks = [ - session_support.RestartComputation(timeout_ms=60 * 1000), - ] - else: - raise ValueError( - 'Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % shutdown_mode) - - shutdown_hooks.append( - session_support.GracefulShutdownHook( - checkpoint_prefix=self.model_dir + '/model.ckpt', - on_shutdown_hooks=finalizer_hooks)) - - with ops.control_dependencies([loss]): - global_step = array_ops.identity(training.get_global_step()) - hooks = input_hooks + shutdown_hooks - hooks.extend([ - TPUInfeedOutfeedSessionHook( - ctx, - enqueue_ops, - host_ops, - run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode], - master=self._config.master, - session_config=self._session_config, - ), - InstallSignalHandlerHook() - ]) - if self._log_every_n_steps is not None: - logging_hook_frequency = ( # Divide and round up - (self._log_every_n_steps + - self._config.tpu_config.iterations_per_loop - 1) // - self._config.tpu_config.iterations_per_loop) - 
hooks.append( - training.LoggingTensorHook({ - 'loss': array_ops.identity(loss), - 'step': global_step, - }, - every_n_iter=logging_hook_frequency)) - examples_hook._set_steps_per_run( # pylint: disable=protected-access - self._config.tpu_config.iterations_per_loop) - hooks.append(examples_hook) - - if training_hooks: - hooks.extend(training_hooks) - - chief_hooks = [] - if (self._config.save_checkpoints_secs or - self._config.save_checkpoints_steps): - checkpoint_hook = training.CheckpointSaverHook( - self.model_dir, - save_secs=self._config.save_checkpoints_secs, - save_steps=self._config.save_checkpoints_steps, - scaffold=scaffold) - checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access - self._config.tpu_config.iterations_per_loop) - chief_hooks.append(checkpoint_hook) - - summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) - with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops(ctx) - - # Validate the TPU training graph to catch basic errors - _validate_tpu_training_graph() - - train_op = control_flow_ops.group(*update_ops) - graph.add_to_collection(_TPU_TRAIN_OP, train_op) - - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - training_chief_hooks=chief_hooks, - training_hooks=hooks, - train_op=train_op, - scaffold=scaffold) - - if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) - iterations_per_loop_var = _create_or_get_iterations_per_loop() - mean_loss = math_ops.div( - total_loss, - math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) - - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), - # reads all variables back from TPU and updates the eval step - # counter properly - internal_ops_to_run = _sync_variables_ops(ctx) - internal_ops_to_run.append( - _increase_eval_step_op(iterations_per_loop_var)) - - host_call_ret = 
host_calls.create_tpu_hostcall() - eval_metric_ops = {} - eval_update_ops = [] - - eval_metrics = host_call_ret.get('eval_metrics', {}) - if eval_metrics: - # Creates a dummy metric update_op for all metrics. Estimator - # expects all metrics in `eval_metric_ops` have update_op and calls - # them one by one. The real metric update_ops are invoked in a - # separated thread. So, here give Estimator the dummy op for all - # metrics. - with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() - - for k, v in eval_metrics.items(): - eval_metric_ops[k] = (v[0], dummy_update_op) - eval_update_ops.append(v[1]) - else: - # If no eval metrics are passed, create an identity node for the - # loss and add `internal_ops_to_run` to its dependencies. So - # `internal_ops_to_run` can be executed. - with ops.control_dependencies(internal_ops_to_run): - mean_loss = array_ops.identity(mean_loss) - - if 'host_call' not in host_call_ret: - host_ops = [] - else: - host_ops = host_call_ret['host_call'] - hooks = [ - TPUInfeedOutfeedSessionHook( - ctx, - enqueue_ops, - eval_update_ops + host_ops, - run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode], - master=self._config.evaluation_master, - session_config=self._session_config, - )] + input_hooks - - if eval_hooks: - hooks.extend(eval_hooks) - - return model_fn_lib.EstimatorSpec( - mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops, - scaffold=scaffold) - - # Predict - assert mode == model_fn_lib.ModeKeys.PREDICT - - (dummy_predict_op, host_calls, - scaffold, prediction_hooks) = _predict_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) - with ops.control_dependencies([dummy_predict_op]): - internal_ops_to_run = _sync_variables_ops(ctx) - with ops.control_dependencies(internal_ops_to_run): - dummy_predict_op = control_flow_ops.no_op() - - # In train and evaluation, the main TPU program is passed to monitored - # 
training session to run. Infeed enqueue and outfeed dequeue are - # executed in side threads. This is not the configuration for - # prediction mode. - # - # For prediction, the Estimator executes the EstimatorSpec.predictions - # directly and yield the element (via generator) to call site. So, the - # outfeed based prediction must be passed to MonitoredSession directly. - # Other parts of the TPU execution are organized as follows. - # - # 1. All outfeed based Tensors must be grouped with predictions Tensors - # to form a single invocation. This avoid the issue we might trigger - # multiple outfeeds incorrectly. To achieve this, `host_call` is - # placed in control_dependencies of `stopping_signals`, and - # `stopping_signals` is passed into _StoppingPredictHook, which sets - # the `stopping_signals` as SessionRunArgs. MonitoredSession merges - # all SessionRunArgs with the fetch in session.run together. - # - # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue) - # are grouped together. They will be launched once and only once in - # side threads and they quit naturally according to the SAME stopping - # condition. - enqueue_ops.append(dummy_predict_op) - - host_call_ret = host_calls.create_tpu_hostcall() - if 'host_call' not in host_call_ret: - host_ops = [] - else: - host_ops = host_call_ret['host_call'] - - predictions = host_call_ret['predictions'] - _verify_cross_hosts_transfer_size( - predictions, - message=( - 'The estimated size for TPUEstimatorSpec.predictions is too ' - 'large.')) - signals = host_call_ret['signals'] - - with ops.control_dependencies(host_ops): - host_ops = [] # Empty, we do do not need it anymore. 
- scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( - signals) - predictions = _PaddingSignals.slice_tensor_or_dict( - predictions, signals) - - hooks = [ - _StoppingPredictHook(scalar_stopping_signal), - TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode], - master=self._config.master, - session_config=self._session_config), - ] + input_hooks - - if prediction_hooks: - hooks.extend(prediction_hooks) - - return model_fn_lib.EstimatorSpec( - mode, - prediction_hooks=hooks, - predictions=predictions, - scaffold=scaffold) - - return _model_fn - - -def _is_tpu_tensor(tensor): - if not isinstance(tensor, ops.Tensor): - return False - try: - tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access - except ValueError: - return True - else: - return False - - -def _export_output_to_tensors(export_output): - """Get a list of `Tensors` used in `export_output`. - - Args: - export_output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - - Returns: - a list of tensors used in export_output. - - Raises: - ValueError: if `export_output` is not one of `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - """ - if isinstance(export_output, export_output_lib.ClassificationOutput): - return [export_output.scores, export_output.classes] - elif isinstance(export_output, export_output_lib.RegressionOutput): - return [export_output.value] - elif isinstance(export_output, export_output_lib.PredictOutput): - return list(export_output.outputs.values()) - else: - raise ValueError( - '`export_output` must be have type `ClassificationOutput`, ' - '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) - - -def _clone_export_output_with_tensors(export_output, tensors): - """Clones `export_output` but with new `tensors`. 
- - Args: - export_output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - tensors: a list of `Tensors` used to construct a new `export_output`. - - Returns: - A dict similar to `export_output` but with `tensors`. - - Raises: - ValueError: if `export_output` is not one of `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - """ - if isinstance(export_output, export_output_lib.ClassificationOutput): - if len(tensors) != 2: - raise ValueError('tensors must be of length 2; ' - 'got {}.'.format(len(tensors))) - return export_output_lib.ClassificationOutput(*tensors) - elif isinstance(export_output, export_output_lib.RegressionOutput): - if len(tensors) != 1: - raise ValueError('tensors must be of length 1; ' - 'got {}'.format(len(tensors))) - return export_output_lib.RegressionOutput(*tensors) - elif isinstance(export_output, export_output_lib.PredictOutput): - return export_output_lib.PredictOutput( - dict(zip(export_output.outputs.keys(), tensors))) - else: - raise ValueError( - '`export_output` must be have type `ClassificationOutput`, ' - '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) - - -def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - iterations_per_loop_var = _create_or_get_iterations_per_loop() - - (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks - ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) - - def multi_tpu_eval_steps_on_single_shard(): - return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step, - [_ZERO_LOSS]) - - (loss,) = tpu.shard( - multi_tpu_eval_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold, captured_eval_hooks.get() - - -def 
_train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - iterations_per_loop_var = _create_or_get_iterations_per_loop() - - (single_tpu_train_step, host_call, captured_scaffold_fn, - captured_training_hooks) = ( - model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) - - def multi_tpu_train_steps_on_single_shard(): - return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, - [_INITIAL_LOSS]) - - (loss,) = tpu.shard( - multi_tpu_train_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold, captured_training_hooks.get() - - -def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - (single_tpu_predict_step, host_calls, captured_scaffold_fn, - captured_predict_hooks - ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) - - def multi_tpu_predict_steps_on_single_shard(): - - def cond(scalar_stopping_signal): - return math_ops.logical_not( - _StopSignals.should_stop(scalar_stopping_signal)) - - inputs = [_StopSignals.NON_STOPPING_SIGNAL] - outputs = training_loop.while_loop( - cond, single_tpu_predict_step, inputs=inputs, name=b'loop') - return outputs - - (dummy_predict_op,) = tpu.shard( - multi_tpu_predict_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() - - -def _wrap_computation_in_while_loop(device, op_fn): - """Wraps the ops generated by `op_fn` in tf.while_loop.""" - - def computation(i): - with ops.control_dependencies(op_fn()): - return i + 1 - - iterations_per_loop_var = 
_create_or_get_iterations_per_loop() - # By setting parallel_iterations=1, the parallel execution in while_loop is - # basically turned off. - with ops.device(device): - iterations = array_ops.identity(iterations_per_loop_var) - return control_flow_ops.while_loop( - lambda i: i < iterations, - computation, [constant_op.constant(0)], - parallel_iterations=1) - - -def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn): - """Wraps the ops generated by `op_fn` in tf.while_loop.""" - - def cond(scalar_stopping_signal): - return math_ops.logical_not( - _StopSignals.should_stop(scalar_stopping_signal)) - - def computation(unused_scalar_stopping_signal): - return_value = op_fn() - execute_ops = return_value['ops'] - signals = return_value['signals'] - with ops.control_dependencies(execute_ops): - return _StopSignals.as_scalar_stopping_signal(signals) - - # By setting parallel_iterations=1, the parallel execution in while_loop is - # basically turned off. - with ops.device(device): - return control_flow_ops.while_loop( - cond, - computation, [_StopSignals.NON_STOPPING_SIGNAL], - parallel_iterations=1) - - -def _validate_tpu_training_graph(): - """Validate graph before running distributed training. - - Raises: - ValueError: If the graph seems invalid for running on device - """ - operations = ops.get_default_graph().get_operations() - - # Check if there is atleast one CrossReplicaSum operation in the graph - # This should be introduced by using the CrossShardOptimizer wrapper - cross_replica_sum_ops = [ - o for o in operations if o.type == _CROSS_REPLICA_SUM_OP - ] - if not cross_replica_sum_ops: - raise ValueError( - 'CrossShardOptimizer must be used for model training on TPUs.') - - -class _CapturedObject(object): - """A placeholder to capture an object. - - This is useful when we need to capture a Python object in the Tensorflow - control flow body function and use it outside the control flow. 
- """ - - def __init__(self): - self._object = None - self._captured = False - - def capture(self, o): - if self._captured: - raise RuntimeError( - 'InternalError: Object can capture only once. Please file bug.') - - self._captured = True - self._object = o - - def get(self): - if not self._captured: - raise RuntimeError( - 'InternalError: Object is not captured properly before `get`. ' - 'Please file bug.') - return self._object - - -def _get_scaffold(captured_scaffold_fn): - """Retrieves the Scaffold from `captured_scaffold_fn`.""" - with _CapturingContext(message='Inside scaffold_fn'): - scaffold_fn = captured_scaffold_fn.get() - if scaffold_fn: - scaffold = scaffold_fn() - if scaffold is None: - raise ValueError( - 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') - else: - scaffold = None - - if scaffold: - wrapped_finalize = scaffold.finalize - - def _finalize(): - with _CapturingContext('Inside Scaffold.finalize'): - wrapped_finalize() - - scaffold.finalize = _finalize - return scaffold - - -class _CapturingContext(control_flow_ops.ControlFlowContext): - """Tracks references to Tensors defined in TPU replication.""" - - def __init__(self, message): - control_flow_ops.ControlFlowContext.__init__(self) - self._message = message - - def to_control_flow_context_def(self, context_def, export_scope=None): - # pylint: disable=useless-super-delegation - # NOTE(slebedev): the method is required by `ControlFlowContext`. 
- super(_CapturingContext, self).to_control_flow_context_def( - context_def, export_scope) - - def AddOp(self, op): # pylint: disable=invalid-name - for c in op.inputs: - if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access - raise ValueError('{}: Op {} depends on TPU computation {}, ' - 'which is not allowed.'.format(self._message, op, c)) - - def __enter__(self): - # pylint: disable=protected-access - self._g = ops.get_default_graph() - self._old = self._g._get_control_flow_context() - self._g._set_control_flow_context(self) - # pylint: enable=protected-access - - def __exit__(self, _, __, ___): # pylint: disable=invalid-name - self._g._set_control_flow_context(self._old) # pylint: disable=protected-access - - -class _Inputs(object): - """A data structure representing the input_fn returned values. - - This also supports the returned value from input_fn as `Dataset`. - """ - - def __init__(self, features=None, labels=None, dataset=None, signals=None): - if dataset is not None and (features is not None or labels is not None or - signals is not None): - raise RuntimeError('Internal Error: Either (features and labels) or ' - 'dataset should be provided, not both. 
Please file ' - 'bug') - - self._features = features - self._labels = labels - self._signals = signals - - self._dataset = dataset - self._iterator = None - - @staticmethod - def from_input_fn(return_values): - """Returns an `_Inputs` instance according to `input_fn` return value.""" - if isinstance(return_values, dataset_ops.DatasetV2): - dataset = return_values - return _Inputs(dataset=dataset) - - features, labels = _Inputs._parse_inputs(return_values) - return _Inputs(features, labels) - - @staticmethod - def _parse_inputs(return_values): - if isinstance(return_values, tuple): - features, labels = return_values - else: - features, labels = return_values, None - return features, labels - - @property - def is_dataset(self): - """Returns True if the return value from input_fn is Dataset.""" - return self._dataset is not None - - def dataset_initializer(self): - """Returns the dataset's initializer. - - The initializer must be run before calling `features_and_labels`. - """ - self._iterator = dataset_ops.make_initializable_iterator(self._dataset) - return self._iterator.initializer - - def features_and_labels(self): - """Gets `features` and `labels`.""" - if self.is_dataset: - if self._iterator is None: - raise RuntimeError('Internal error: Must run dataset_initializer ' - 'before calling features_and_labels(). 
Please file ' - 'a bug!') - return _Inputs._parse_inputs(self._iterator.get_next()) - - return (self._features, self._labels) - - def signals(self): - return self._signals - - @property - def dataset(self): - return self._dataset - - -class _InputsWithStoppingSignals(_Inputs): - """Inputs with `_StopSignals` inserted into the dataset.""" - - def __init__(self, - dataset, - batch_size, - add_padding=False, - num_invocations_per_step=1): - - assert dataset is not None - user_provided_dataset = dataset.map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=False, batch_size=batch_size, add_padding=add_padding)) - if num_invocations_per_step == 1: - final_batch_dataset = dataset.take(1).map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size, add_padding=add_padding)) - else: - # We append (2 * num_invocations_per_step - 1) batches for exhausting the - # user_provided_dataset and stop properly. - # For example, if num_invocations_per_step is 2, we append 3 additional - # padding batches: b1, b2, b3. - # If user_provided_dataset contains two batches: a1, a2 - # Step 1: [a1, a2] - # Step 2: [b1, b2] -> STOP - # If user_provided_dataset contains three batches: a1, a2, a3. - # The training loops: - # Step 1: [a1, a2] - # Step 2: [a3, b1] - # Step 3: [b2, b3] -> STOP. - final_batch_dataset = dataset.take(1).map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size, add_padding=add_padding)) - final_batch_dataset = final_batch_dataset.repeat( - 2 * num_invocations_per_step - 1) - - def _set_mask(data_dict): - signals = data_dict['signals'] - signals['padding_mask'] = array_ops.ones_like(signals['padding_mask']) - data_dict['signals'] = signals - return data_dict - - # Mask out the extra batch. 
- final_batch_dataset = final_batch_dataset.map(_set_mask) - - dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) - - super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) - self._current_inputs = None - - def features_and_labels(self): - if self._current_inputs is not None: - raise RuntimeError( - 'Internal Error: The previous inputs have not been properly ' - 'consumed. First call features_and_labels, then call signals.') - - inputs_with_signals = self._iterator.get_next() - features = inputs_with_signals['features'] - labels = inputs_with_signals.get('labels') - - self._current_inputs = inputs_with_signals - return features, labels - - def signals(self): - """Returns the `Signals` from `_Inputs`.""" - if self._current_inputs is None: - raise RuntimeError( - 'Internal Error: The current inputs have not been properly ' - 'generated. First call features_and_labels, then call signals.') - signals = self._current_inputs['signals'] - self._current_inputs = None - return signals - - @staticmethod - def insert_stopping_signal(stop, batch_size, add_padding=False): - """Inserts stopping_signal into dataset via _map_fn. - - Here we change the data structure in the dataset, such that the return value - is a dictionary now and `features`, `labels`, and `signals` are three - distinguished keys in that dict. This provides a better structure, which - eases the process to decompose the inputs (see `features_and_labels`). - - Args: - stop: bool, state of current stopping signals. - batch_size: int, batch size. - add_padding: bool, whether to pad the tensor to full batch size. - - Returns: - A map_fn passed to dataset.map API. - """ - - def _map_fn(*args): - """The map fn to insert signals.""" - if len(args) == 1: - # Unpack the single Tensor/dict argument as features. This is required - # for the input_fn returns no labels. 
- args = args[0] - features, labels = _Inputs._parse_inputs(args) - new_input_dict = {} - - if add_padding: - padding_mask, features, labels = ( - _PaddingSignals.pad_features_and_labels(features, labels, - batch_size)) - - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels - - else: - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels - padding_mask = None - - new_input_dict['signals'] = _StopSignals( - stop=stop, batch_size=batch_size, - padding_mask=padding_mask).as_dict() - - return new_input_dict - - return _map_fn - - -class _StopSignals(object): - """Signals class holding all logic to handle TPU stopping condition.""" - - NON_STOPPING_SIGNAL = False - STOPPING_SIGNAL = True - - def __init__(self, stop, batch_size, padding_mask=None): - self._stop = stop - self._batch_size = batch_size - self._padding_mask = padding_mask - - def as_dict(self): - """Returns the signals as Python dict.""" - shape = [self._batch_size, 1] - dtype = dtypes.bool - - if self._stop: - stopping = array_ops.ones(shape=shape, dtype=dtype) - else: - stopping = array_ops.zeros(shape=shape, dtype=dtype) - - signals = {'stopping': stopping} - if self._padding_mask is not None: - signals['padding_mask'] = self._padding_mask - return signals - - @staticmethod - def as_scalar_stopping_signal(signals): - return array_ops.identity(signals['stopping'][0][0]) - - @staticmethod - def should_stop(scalar_stopping_signal): - """Detects whether scalar_stopping_signal indicates stopping.""" - if isinstance(scalar_stopping_signal, ops.Tensor): - # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF - # way to express the bool check whether scalar_stopping_signal is True. - return math_ops.logical_and(scalar_stopping_signal, - _StopSignals.STOPPING_SIGNAL) - else: - # For non Tensor case, it is used in SessionRunHook. So, we cannot modify - # the graph anymore. Here, we use pure Python. 
- return bool(scalar_stopping_signal) - - -class _PaddingSignals(object): - """Signals class holding all logic to handle padding.""" - - @staticmethod - def pad_features_and_labels(features, labels, batch_size): - """Pads out the batch dimension of features and labels.""" - real_batch_size = array_ops.shape( - _PaddingSignals._find_any_tensor(features))[0] - - batch_size_tensor = constant_op.constant(batch_size, dtypes.int32) - - check_greater = check_ops.assert_greater_equal( - batch_size_tensor, - real_batch_size, - data=(batch_size_tensor, real_batch_size), - message='The real batch size should not be greater than batch_size.') - - with ops.control_dependencies([check_greater]): - missing_count = batch_size_tensor - real_batch_size - - def pad_single_tensor(tensor): - """Pads out the batch dimension of a tensor to the complete batch_size.""" - rank = len(tensor.shape) - assert rank > 0 - padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) - padded_shape = (batch_size,) + tuple(tensor.shape[1:]) - padded_tensor = array_ops.pad(tensor, padding) - padded_tensor.set_shape(padded_shape) - return padded_tensor - - def nest_pad(tensor_or_dict): - return nest.map_structure(pad_single_tensor, tensor_or_dict) - - features = nest_pad(features) - if labels is not None: - labels = nest_pad(labels) - - padding_mask = _PaddingSignals._padding_mask(real_batch_size, missing_count, - batch_size) - - return padding_mask, features, labels - - @staticmethod - def slice_tensor_or_dict(tensor_or_dict, signals): - """Slice the real Tensors according to padding mask in signals.""" - - padding_mask = signals['padding_mask'] - batch_size = array_ops.shape(padding_mask)[0] - - def verify_batch_size(tensor): - check_batch_size = math_ops.equal(batch_size, tensor.shape[0]) - with ops.control_dependencies([check_batch_size]): - return array_ops.identity(tensor) - - def slice_single_tensor(tensor): - rank = len(tensor.shape) - assert rank > 0 - real_batch_size = batch_size 
- math_ops.reduce_sum(padding_mask) - return verify_batch_size(tensor)[0:real_batch_size] - - # As we split the Tensors to all TPU cores and concat them back, it is - # important to ensure the real data is placed before padded ones, i.e., - # order is preserved. By that, the sliced padding mask should have all 0's. - # If this assertion failed, # the slice logic here would not hold. - sliced_padding_mask = slice_single_tensor(padding_mask) - assert_padding_mask = math_ops.equal( - math_ops.reduce_sum(sliced_padding_mask), 0) - - with ops.control_dependencies([assert_padding_mask]): - should_stop = _StopSignals.should_stop( - _StopSignals.as_scalar_stopping_signal(signals)) - - is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0) - - def slice_fn(tensor): - # If the current batch is full batch or part of stopping signals, we do - # not need to slice to save performance. - return control_flow_ops.cond( - math_ops.logical_or(should_stop, is_full_batch), - (lambda: verify_batch_size(tensor)), - (lambda: slice_single_tensor(tensor))) - - return nest.map_structure(slice_fn, tensor_or_dict) - - @staticmethod - def _find_any_tensor(batch_features): - tensors = [ - x for x in nest.flatten(batch_features) if isinstance(x, ops.Tensor) - ] - if not tensors: - raise ValueError('Cannot find any Tensor in features dict.') - return tensors[0] - - @staticmethod - def _padding_mask(real_batch_size, missing_count, batch_size): - padding_mask = array_ops.concat([ - array_ops.zeros((real_batch_size,), dtype=dtypes.int32), - array_ops.ones((missing_count,), dtype=dtypes.int32) - ], - axis=0) - padding_mask.set_shape((batch_size,)) - return padding_mask - - -def _verify_cross_hosts_transfer_size(tensor_dict, message): - total_size = 0 - tensor_structure = {} - for key, tensor in tensor_dict.items(): - shape = tensor.shape - size = np.product(shape) * tensor.dtype.size - tensor_structure[key] = shape - total_size += size - if total_size >= _ONE_GIGABYTE: - raise 
ValueError( - '{} The transfer size is larger than the protobuf limit. Please ' - 'consider to use Tensors with smaller shapes or reduce batch ' - 'size. Given:\n' - '{}'.format( - message, '\n'.join([ - ' -- Key: {}, Shape: {}'.format(k, v) - for k, v in tensor_structure.items() - ]))) - - -def _add_item_to_params(params, key, value): - """Adds a new item into `params`.""" - if isinstance(params, hparam.HParams): - # For HParams, we need to use special API. - if key in params: - params.set_hparam(key, value) - else: - params.add_hparam(key, value) - else: - # Now params is Python dict. - params[key] = value - - -def export_estimator_savedmodel(estimator, - export_dir_base, - serving_input_receiver_fn, - assets_extra=None, - as_text=False, - checkpoint_path=None, - strip_default_attrs=False): - """Export `Estimator` trained model for TPU inference. - - Args: - estimator: `Estimator` with which model has been trained. - export_dir_base: A string containing a directory in which to create - timestamped subdirectories containing exported SavedModels. - serving_input_receiver_fn: A function that takes no argument and returns a - `ServingInputReceiver` or `TensorServingInputReceiver`. - assets_extra: A dict specifying how to populate the assets.extra directory - within the exported SavedModel, or `None` if no extra assets are needed. - as_text: whether to write the SavedModel proto in text format. - checkpoint_path: The checkpoint path to export. If `None` (the default), - the most recent checkpoint found within the model directory is chosen. - strip_default_attrs: Boolean. If `True`, default-valued attributes will be - removed from the NodeDefs. - - Returns: - The string path to the exported directory. - """ - # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use - # `estimator.config`. 
- config = tpu_config.RunConfig(model_dir=estimator.model_dir) - est = TPUEstimator( - estimator._model_fn, # pylint: disable=protected-access - config=config, - params=estimator.params, - use_tpu=True, - train_batch_size=2048, # Does not matter. - eval_batch_size=2048, # Does not matter. - ) - return est.export_savedmodel(export_dir_base, serving_input_receiver_fn, - assets_extra, as_text, checkpoint_path, - strip_default_attrs) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_estimator import * +# used by tests +from tensorflow.python.tpu.tpu_estimator import _clone_export_output_with_tensors +from tensorflow.python.tpu.tpu_estimator import _create_global_step +from tensorflow.python.tpu.tpu_estimator import _export_output_to_tensors +from tensorflow.python.tpu.tpu_estimator import _get_scaffold +from tensorflow.python.tpu.tpu_estimator import _Inputs +from tensorflow.python.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR +from tensorflow.python.tpu.tpu_estimator import _TPU_ENQUEUE_OPS +from tensorflow.python.tpu.tpu_estimator import _TPU_ESTIMATOR +from tensorflow.python.tpu.tpu_estimator import _TPU_TRAIN_OP +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index d5957b7e8ec40b40c7af8822378cee6134ef0d0f..af2542ea85290170ce6a38223188c4f9b871f032 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -1,898 +1,25 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== - -"""Helper library for handling infeed between hosts and TPUs. -""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools - -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_sharding - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.util import nest - - -class InfeedQueue(object): - """A helper object to build a device infeed queue. - - The InfeedQueue builds the host-side and device-side Ops to enqueue and - dequeue elements, respectively, and ensures that their types and - shapes match. - """ - - def __init__(self, - number_of_tuple_elements=None, - tuple_types=None, - tuple_shapes=None, - shard_dimensions=None, - name=None): - """Creates a new InfeedQueue with the given configuration. 
- - The configuration need not be fully specified at creation since it - can be modified subsequently by methods that set the values - explicitly or infer them from the shapes of inputs. - - Args: - number_of_tuple_elements: the number of Tensors fed atomically through the - queue, must be present unless it can be inferred from other arguments. - tuple_types: if not None, a list of types of the elements of the queue. - tuple_shapes: if not None, a list of shapes of the elements of the queue. - shard_dimensions: if not None, a list of dimensions on which the - elements of the queue should be sharded during automatic - parallelization. - name: the name of the queue. - - Raises: - ValueError: if number_of_tuple_elements <= 0; or - number_of_tuple_arguments, tuple_types, tuple_shapes, and - shard_dimensions are all None; or the length of tuple_types, - tuple_shapes, or shard_dimensions is not equal to - number_of_tuple_elements; or any element of shard_dimensions - can't be converted to a Dimension. - TypeError: if any element of tuple_types or tuple_shapes can't - be converted to a dtype or TensorShape, respectively. - """ - self._frozen = False - self._generated_enqueue_ops = False - self._generated_dequeue_op = False - self._name = "InfeedQueue" if name is None else name - if number_of_tuple_elements is None: - if tuple_types is not None: - number_of_tuple_elements = len(tuple_types) - elif tuple_shapes is not None: - number_of_tuple_elements = len(tuple_shapes) - elif shard_dimensions is not None: - number_of_tuple_elements = len(shard_dimensions) - else: - raise ValueError( - "number of tuple elements cannot be inferred from InfeedQueue " - "constructor") - if number_of_tuple_elements <= 0: - raise ValueError("number_of_tuple_elements %d must be > 0" % - number_of_tuple_elements) - # Make an empty sharding policy for each tuple element. 
- self._sharding_policies = [ - tpu_sharding.ShardingPolicy() - for _ in xrange(number_of_tuple_elements) - ] - if tuple_types is not None: - self.set_tuple_types(tuple_types) - else: - self._tuple_types = None - if tuple_shapes is not None: - self.set_tuple_shapes(tuple_shapes) - else: - self._tuple_shapes = None - if shard_dimensions is not None: - self.set_shard_dimensions(shard_dimensions) - self._validate() - - def _validate(self): - """Checks that the configuration is self-consistent. - - Raises: - ValueError: if the shapes and sharding policies don't match. - """ - if self.tuple_shapes is not None: - for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes): - # Raise an error if the policy is incompatible with the shape. - _ = policy.get_sharded_shape(shape) - - @property - def number_of_tuple_elements(self): - """Returns the number of InfeedQueue tuple elements.""" - return len(self._sharding_policies) - - @property - def tuple_types(self): - """Returns the types of the InfeedQueue tuple elements.""" - return self._tuple_types - - def set_tuple_types(self, tuple_types): - """Sets the type of each element of the queue. - - tuple_types must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a dtype. - - Args: - tuple_types: the types of each queue element. - - Raises: - ValueError: if tuple_types is not of length - self.number_of_tuple_elements. - TypeError: if an element of tuple_types cannot be converted to a - dtype. - """ - if len(tuple_types) != self.number_of_tuple_elements: - raise ValueError("tuple_types is %s, but must be a list of length %d" % - (str(tuple_types), self.number_of_tuple_elements)) - if self._frozen: - for (frozen, updated) in zip(self._tuple_types, tuple_types): - if frozen != updated: - raise ValueError( - "Trying to update InfeedQueue with frozen configuration with an " - "incompatible type. 
Frozen types are %s, updated types are %s" % ( - str(self._tuple_types), str(tuple_types))) - else: - try: - self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types] - except (TypeError) as e: - raise TypeError( - "tuple_types is %s, but must be a list of elements each " - "convertible to dtype: got error %s" % (str(tuple_types), str(e))) - - @property - def tuple_shapes(self): - """Returns the shapes of the InfeedQueue tuple elements.""" - return self._tuple_shapes - - def set_tuple_shapes(self, tuple_shapes): - """Sets the shape of each element of the queue. - - tuple_shapes must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a TensorShape. - - Args: - tuple_shapes: the shapes of each queue element. - - Raises: - ValueError: if tuple_shapes is not of length - self.number_of_tuple_elements. - TypeError: if an element of tuple_shapes cannot be converted to - a TensorShape. - """ - if len(tuple_shapes) != self.number_of_tuple_elements: - raise ValueError("tuple_shapes is %s, but must be a list of length %d" % - (str(tuple_shapes), self.number_of_tuple_elements)) - try: - tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes] - except (ValueError, TypeError) as e: - raise TypeError( - "tuple_shapes is %s, but must be a list of elements each " - "convertible to TensorShape: got error %s" % (str(tuple_shapes), - str(e))) - if self._frozen: - for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes): - if frozen != updated: - raise ValueError( - "Trying to update InfeedQueue with frozen configuration with an " - "incompatible shape. 
Frozen shapes are %s, updated shapes are %s" - % (str(self._tuple_shapes), str(tuple_shapes))) - else: - self._tuple_shapes = tuple_shapes - self._validate() - - @property - def sharding_policies(self): - """Returns the sharding policies of the InfeedQueue tuple elements.""" - return self._sharding_policies - - @property - def shard_dimensions(self): - """Gets the shard dimension of each tuple element. - - Returns: - A list of length number_of_tuple_elements, where each list entry - is the shard dimension of that tuple element or None if the - shard dimension has not been set. - """ - # The number of shards is always the same for all the policies. - return [policy.shard_dimension for policy in self._sharding_policies] - - def set_shard_dimensions(self, shard_dimensions): - """Sets the shard_dimension of each element of the queue. - - shard_dimensions must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a Dimension compatible with self.tuple_shapes. - - Args: - shard_dimensions: the dimensions of each queue element. - - Raises: - ValueError: if shard_dimensions is not of length - self.number_of_tuple_elements; or an element of - shard_dimensions cannot be converted to a Dimension; or an - element of shard_dimensions is a Dimension that is out of - range for the corresponding tuple element shape. - """ - if len(shard_dimensions) != self.number_of_tuple_elements: - raise ValueError("shard_dimensions is %s, but must be a list of length %d" - % (str(shard_dimensions), - self.number_of_tuple_elements)) - for (policy, dimension) in zip(self._sharding_policies, shard_dimensions): - policy.set_shard_dimension(dimension) - self._validate() - - @property - def number_of_shards(self): - """Gets the number of shards to use for the InfeedQueue. - - Returns: - Number of shards or None if the number of shards has not been set. - """ - # The number of shards is always the same for all the policies. 
- return self._sharding_policies[0].number_of_shards - - def set_number_of_shards(self, number_of_shards): - """Sets the number of shards to use for the InfeedQueue. - - Args: - number_of_shards: number of ways to shard the InfeedQueue. - - Raises: - ValueError: if number_of_shards is not > 0; or the policies have - been frozen and number_of_shards was already set to something - else. - """ - for policy in self._sharding_policies: - policy.set_number_of_shards(number_of_shards) - self._validate() - - def set_configuration_from_input_tensors(self, input_tensors): - """Sets the shapes and types of the queue tuple elements. - - input_tensors is a list of Tensors whose types and shapes are used - to set the queue configuration. - - Args: - input_tensors: list of Tensors of the same types and shapes as - the desired queue Tuple. - - Raises: - ValueError: if input_tensors is not a list of length - self.number_of_tuple_elements - """ - if len(input_tensors) != self.number_of_tuple_elements: - raise ValueError("input_tensors is %s, but should be a list of %d Tensors" - % (str(input_tensors), self.number_of_tuple_elements)) - self.set_tuple_shapes([t.shape for t in input_tensors]) - self.set_tuple_types([t.dtype for t in input_tensors]) - - def set_configuration_from_sharded_input_tensors(self, input_tensors): - """Sets the shapes and types of the queue tuple elements. - - input_tensors is a list of lists of Tensors whose types and shapes are used - to set the queue configuration. The length of the outer list is the number - of shards required, and each inner list is the tuple of Tensors to use to - determine the types and shapes of the corresponding shard. This method - depends on the shard dimension, and calling it freezes the shard policy. - - Args: - input_tensors: list of lists of Tensors. The outer list length corresponds - to the desired number of shards, and each inner list is the size - and shape of the desired configuration of the corresponding shard. 
- - Raises: - ValueError: if any inner list is not a list of length - self.number_of_tuple_elements; or the inner lists do not combine to - form a consistent unsharded shape. - TypeError: if the types of the Tensors in the inner lists do not match. - """ - if not self._frozen: - # Unset the tuple shapes in case the configuration becomes - # transiently inconsistent. - self._tuple_shapes = None - number_of_shards = len(input_tensors) - self.set_number_of_shards(number_of_shards) - for t in input_tensors: - if len(t) != self.number_of_tuple_elements: - raise ValueError( - "input_tensors is %s but must be a list of lists, where each inner" - " list has length number_of_tuple_elements=%d" % ( - str(input_tensors), self.number_of_tuple_elements)) - # Transpose the inputs to make a list of shard shapes for each tuple - # element. - sharded_shapes = [[t[i].shape for t in input_tensors] - for i in xrange(self.number_of_tuple_elements)] - # For each tuple, get the unsharded shape using that tuple's policy. - unsharded_shapes = [ - policy.get_unsharded_shape(s) - for (policy, s) in zip(self._sharding_policies, sharded_shapes) - ] - self.set_tuple_shapes(unsharded_shapes) - for i in xrange(1, self.number_of_shards): - for (t1, t2) in zip(input_tensors[0], input_tensors[i]): - if t1.dtype != t2.dtype: - raise TypeError( - "types of the tuple elements of input_tensors %s are not " - "consistent" % str(input_tensors)) - self.set_tuple_types([t.dtype for t in input_tensors[0]]) - - def freeze(self): - """Freezes the InfeedQueue so it can no longer be modified. - - The configuration is implicitly frozen before any host-side or - device-side Ops are generated. The configuration cannot be frozen - until the types and shapes of the tuple elements have been set. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set. 
- """ - self._frozen = True - if self._tuple_types is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple types.") - if self._tuple_shapes is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple shapes.") - for shape in self._tuple_shapes: - if shape.dims is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple shapes.") - for policy in self._sharding_policies: - policy.freeze() - self._validate() - - def generate_dequeue_op(self, tpu_device=0): - """Generates the device-side Op to dequeue a tuple from the queue. - - Implicitly freezes the queue configuration if it is not already - frozen, which will raise errors if the shapes and types have not - been fully specified. - - Args: - tpu_device: The TPU device ordinal where the infeed instruction should be - placed. If None, no explicit placement will be performed, and it is up - to the user to call this API from within a proper TPU device scope. - The XLA code will fail if the TPU dequeue instruction is not bound to - any device. - - Returns: - A list of Outputs corresponding to a shard of infeed dequeued - into XLA, suitable for use within a replicated block. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set; or if a dequeue op has already been generated. 
- """ - self.freeze() - if self._generated_dequeue_op: - raise ValueError("Can't generate two dequeue Ops from the same queue") - self._generated_dequeue_op = True - full_name = "%s/dequeue" % self._name - sharded_shapes = [ - policy.get_sharded_shape(shape) - for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) - ] - if tpu_device is not None: - with ops.device(tpu.core(tpu_device)): - return tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - else: - return tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - - def _generate_enqueue_op(self, - inputs, - name_prefix, - index, - device=None, - tpu_ordinal=-1): - """Generate a host-side Op to enqueue a tuple to the queue. - - If device is None the inputs are all required to have the same - device specification, and the enqueue Op is colocated with - inputs[0]. Otherwise the enqueue Op is placed on 'device'. - - Args: - inputs: a list of Tensors with the types and shapes of the tuple elements. - name_prefix: the base name for the Op. - index: the shard index, used to uniquify the Op name. - device: device to place the Op on, or None if it should be - colocated with the inputs. - tpu_ordinal: ordinal of the TPU device on the host to use for - infeed if device is a CPU device. Should be set to -1 if device - is a TPU device. - - Returns: - An Op corresponding to a shard of infeed enqueued at the host, - suitable for use within a replicated block. - - Raises: - ValueError: if device is None and inputs do not all have the - same device specification. 
- """ - full_name = "%s/%d" % (name_prefix, index) - shapes = [t.shape for t in inputs] - if device is None: - devices = [t.device for t in inputs] - for i in xrange(1, self.number_of_tuple_elements): - if devices[0] != devices[i]: - raise ValueError( - "input devices for shard %d are %s, but should all be the same" % - (index, str(devices))) - with ops.colocate_with(inputs[0]): - return tpu_ops.infeed_enqueue_tuple( - inputs=inputs, - shapes=shapes, - name=full_name, - device_ordinal=tpu_ordinal) - else: - with ops.device(device): - return tpu_ops.infeed_enqueue_tuple( - inputs=inputs, - shapes=shapes, - name=full_name, - device_ordinal=tpu_ordinal) - - def generate_enqueue_ops(self, - sharded_inputs, - tpu_ordinal_function=None, - placement_function=None): - """Generates the host-side Ops to enqueue the shards of a tuple. - - sharded_inputs is a list, one for each shard, of lists of - Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed - shard 0 if the queue. Returns the host-side Ops that must be run to - enqueue the sharded tuple. The Op for shard i is colocated with the inputs - for shard i. - - Implicitly freezes the queue configuration if it is not already - frozen. If the configuration has already been frozen, and is not - compatible with the types and shapes of sharded_inputs, an error - will be raised. - - Args: - sharded_inputs: a list of lists of Tensors. The length of the outer list - determines the number of shards. Each inner list indicates the types - and shapes of the tuples in the corresponding shard. - tpu_ordinal_function: if not None, a function that takes the - shard index as input and returns the ordinal of the TPU device - the shard's infeed should be placed on. tpu_ordinal_function must be - set if the inputs are placed on CPU devices. - placement_function: if not None, a function that takes the shard index as - input and returns the host device where the enqueue op should be placed - on. 
- - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. - - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the shapes of the elements of sharded_inputs - don't form a consistent unsharded tuple; or if the elements of a tuple - have different device constraints. - TypeError: if the queue configuration has previously been frozen and the - types of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the types of the elements of sharded_inputs - don't form a consistent unsharded tuple. - """ - self.set_configuration_from_sharded_input_tensors(sharded_inputs) - self.freeze() - if self._generated_enqueue_ops: - raise ValueError("Can't generate two enqueue Ops from the same queue") - self._generated_enqueue_ops = True - if tpu_ordinal_function is None: - tpu_ordinal_function = lambda index: -1 - name_prefix = "%s/enqueue" % self._name - return [ - self._generate_enqueue_op( - shard, - name_prefix, - index, - tpu_ordinal=tpu_ordinal_function(index), - device=placement_function(index) if placement_function else None) - for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) - ] - - # TODO(misard) Generalize this to the case of systems that don't - # have 8 devices per host, and figure out what to do with - # model-parallelism. - def _default_placement_function(self, index): - return "/task:%d/device:CPU:0" % (index / 8) - - def _default_ordinal_function(self, index): - return index % 8 - - # TODO(b/36470756) remove this from tutorials once we have a better story - # for automatic placement of input pipelines. - def split_inputs_and_generate_enqueue_ops(self, - inputs, - device_assignment=None, - placement_function=None, - tpu_ordinal_function=None): - """POORLY-PERFORMING ON MULTI-HOST SYSTEMS. 
- - Generates the host-side Ops to enqueue a tuple. - - This method performs poorly because it takes an entire input on a single - host, splits it, and distributes it to all of the cores. It is present only - to simplify tutorial examples. - - inputs is a list of Tensors to use to feed the queue. Each input is split - into self.number_of_shards shards. Returns an Op for each shard to enqueue - the shard. The Op for shard i is placed on device placement_function(i). - - Implicitly freezes the queue configuration if it is not already - frozen. If the configuration has already been frozen, and is not - compatible with the types and shapes of inputs, an error - will be raised. - - Args: - inputs: a list of Tensors which indicates the types and shapes of the - queue tuple. - device_assignment: if not `None`, a TPU `DeviceAssignment`. If - device_assignment is not `None`, but `placement_function` and - `ordinal_function` are None, then `device_assignment` will be used to - place infeeds on the first k TPU shards, where k is the number of shards - in the queue. If all three are `None`, then default placement and - ordinal functions are used. - placement_function: if not None, a function that takes the shard - index as input and returns a device string indicating which - device the shard's infeed should be placed on. If placement_function - and tpu_ordinal_function are None, inputs are sharded round-robin - across the devices in the system. - tpu_ordinal_function: if not None, a function that takes the - shard index as input and returns the ordinal of the TPU device - the shard's infeed should be placed on. If placement_function - and tpu_ordinal_function are None, inputs are sharded round-robin - across the devices in the system. - - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. 
- - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of inputs are not compatible with the frozen - configuration. - TypeError: if the queue configuration has previously been frozen and the - types of the elements of inputs are not compatible with the frozen - configuration. - """ - if device_assignment is None: - if placement_function is None: - placement_function = self._default_placement_function - if tpu_ordinal_function is None: - tpu_ordinal_function = self._default_ordinal_function - else: - - def _placement_function_from_map(index): - return device_assignment.host_device(replica=index) - - def _ordinal_function_from_map(index): - return device_assignment.tpu_ordinal(replica=index) - - if placement_function is None: - placement_function = _placement_function_from_map - if tpu_ordinal_function is None: - tpu_ordinal_function = _ordinal_function_from_map - self.set_configuration_from_input_tensors(inputs) - self.freeze() - if self._generated_enqueue_ops: - raise ValueError("Can't generate two enqueue Ops from the same queue") - self._generated_enqueue_ops = True - split_name_prefix = "%s/split" % self._name - if self.number_of_shards == 1: - transposed_sharded_inputs = [[inp] for inp in inputs] - else: - - def split_fn(inp, num_shards, axis, name): - with ops.colocate_with(inp): - return array_ops.split(inp, num_shards, axis=axis, name=name) - - transposed_sharded_inputs = [ - split_fn( - inp, - self.number_of_shards, - axis=policy.shard_dimension, - name="%s/%d" % (split_name_prefix, index)) - for (inp, policy, index) in zip(inputs, self._sharding_policies, - xrange(self.number_of_tuple_elements)) - ] - sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs] - for i in xrange(self.number_of_shards)] - name_prefix = "%s/enqueue" % self._name - return [ - self._generate_enqueue_op( - shard, - name_prefix, - index, - device=placement_function(index), - 
tpu_ordinal=tpu_ordinal_function(index)) - for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) - ] - - -class _PartitionedInfeedQueue(InfeedQueue): - """A helper object to build a device infeed queue with input partition. - - Args: - number_of_tuple_elements: the number of Tensors fed atomically through the - queue, must be present unless it can be inferred from other arguments. - device_assignment: A TPU `DeviceAssignment` which is used to place all the - partitions to different TPU infeed queues. - host_id: The id of the host machine. - input_partition_dims: A nested list/tuple of integers. Each inner - list/tuple describes how to partition the corresponding input tensor. - tuple_types: If not None, a list of types of the elements of the queue. - tuple_shapes: If not None, a list of shapes of the elements of the queue. - name: The name of the queue. - """ - - def __init__(self, - number_of_tuple_elements, - device_assignment, - host_id, - input_partition_dims=None, - tuple_types=None, - tuple_shapes=None, - name=None): - super(_PartitionedInfeedQueue, self).__init__( - number_of_tuple_elements=number_of_tuple_elements, - tuple_types=tuple_types, - tuple_shapes=None, - shard_dimensions=None, - name="PartitionedInfeedQueue" if name is None else name) - self._input_partition_dims = input_partition_dims - self._host_id = host_id - self._device_assignment = device_assignment - - def generate_dequeue_op(self, tpu_device=0): - """Generate TPU dequeue ops. - - Args: - tpu_device: The TPU device ordinal where the infeed instruction should be - placed. - - Returns: - A list of Outputs corresponding to a partition of infeed dequeued - into XLA, suitable for use within a replicated block. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set; or if a dequeue op has already been generated. 
- """ - self.freeze() - if self._generated_dequeue_op: - raise ValueError("Can't generate two dequeue Ops from the same queue") - self._generated_dequeue_op = True - full_name = "%s/dequeue" % self._name - sharded_shapes = [ - policy.get_sharded_shape(shape) - for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) - ] - with ops.device(tpu.core(tpu_device)): - values = tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - return self._tag_sharding_attribute_for_dequeued_tensors( - values, self._input_partition_dims) - - def generate_enqueue_ops(self, per_host_sharded_inputs): - """Generates the host-side Ops to enqueue the partitioned inputs. - - per_host_sharded_inputs is a list, one for each replica, of lists of - Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed - replica i. - sharded_inputs[i][j] is partitioned by self._input_partition_dims[j]. - - For example, if sharded_inputs[i][j] is a 2-D Tensor: - [[A, B, C, D], - [E ,F, G, H]] - self._input_partition_dims[j] is [2, 4]. - - sharded_inputs[i][j] will be partitioned and flattened into: - [A, B, C, D, E, F, G, H] and fed into the logical core ids: - [0, 1, 2, 3, 4, 5, 6, 7] respectively. - - Args: - per_host_sharded_inputs: a list of lists of Tensors. The length of the - outer list determines the number of shards. Each inner list indicates - the types and shapes of the tuples in the corresponding shard. - - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. - - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the shapes of the elements of sharded_inputs - don't form a consistent unsharded tuple; or if the elements of a tuple - have different device constraints; or if the partition dims are invalid. 
- TypeError: if the queue configuration has previously been frozen and the - types of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the types of the elements of sharded_inputs - don't form a consistent unsharded tuple. - """ - self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs) - number_of_replicas_per_host = len(per_host_sharded_inputs) - number_of_tuple_elements = len(per_host_sharded_inputs[0]) - - assert len(self._input_partition_dims) == number_of_tuple_elements - per_host_enqueue_ops = [] - - for replica_index in range(number_of_replicas_per_host): - flattened_inputs = per_host_sharded_inputs[replica_index] - inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, - self._input_partition_dims) - inputs_parted_iters = [ - iter(self._partition_or_replicate_on_host(x, dims)) for x, dims in - zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat) - ] - - for logical_core in xrange(self._device_assignment.num_cores_per_replica): - # Places different partitions to different logic cores. - replica_id = self._device_assignment.lookup_replicas( - self._host_id, logical_core)[replica_index] - ordinal = self._device_assignment.tpu_ordinal( - replica=replica_id, logical_core=logical_core) - infeed_inputs = [] - for it in inputs_parted_iters: - input_for_device = next(it, None) - if input_for_device is not None: - infeed_inputs.append(input_for_device) - - if infeed_inputs: - per_host_enqueue_ops.append( - tpu_ops.infeed_enqueue_tuple( - inputs=infeed_inputs, - shapes=[x.shape for x in infeed_inputs], - name="enqueue/replica_{0}/input_{1}".format( - replica_index, logical_core), - device_ordinal=ordinal)) - return per_host_enqueue_ops - - def _check_input_partition_dims(self, tensor, dims): - """Checks that input partition dims are valid for the `Tensor`. - - Args: - tensor: Input tensor for partitioning. 
- dims: 1-D np.array of the list of integer describes how to partition the - input tensor. - - Raises: - ValueError: If the tensor can't be partitioned by dims or the - num_cores_per_replica doesn't match the number of - partitions(dims.prod()). - """ - if (dims < 1).any(): - raise ValueError("All input partition dims must be >= 1.") - - # No partitioning, so don't perform further checks. - if dims.prod() == 1: - return - - if dims.prod() != self._device_assignment.num_cores_per_replica: - raise ValueError( - "The product of each input parition dim should equal to " - "num_cores_per_replica. (dim = {}, num_cores_per_replica " - "= {})".format(dims, self._device_assignment.num_cores_per_replica)) - if dims.shape[0] != tensor.shape.ndims: - raise ValueError( - "Input partition dims must have the same number of dimensions " - "as the `Tensor` to be partitioned. (tensor shape = {}, input " - "partition dims = {}).".format(tensor.shape.as_list(), dims)) - - tensor.shape.assert_is_fully_defined() - - def _partition_or_replicate_on_host(self, tensor, dims): - """Partitions or replicates the input tensor. - - The ops inside this function are placed on the host side. - - Args: - tensor: The input tensor which will be partioned or replicated. - dims: A list of integer describes how to partition the input tensor. - Returns: - An iterator of `Tensor`s or a list of partioned tensors. - """ - if dims is None: - return itertools.repeat(tensor) - dims = np.array(dims) - self._check_input_partition_dims(tensor, dims) - output = [tensor] - shape_list = np.array(tensor.shape.as_list()) - quotients, remainders = np.divmod(shape_list, dims) - for axis, (quotient, remainder, dim, original_size) in enumerate( - zip(quotients, remainders, dims, shape_list)): - if dim <= 1: - continue - if remainder > 0: - # For each dimension, when it cannot be evenly partitioned, XLA assumes - # tensors are partitioned in a greedy manner by using - # ceil_ratio(size/dim) first. E.g. 
2D tensor with shape (5, 14) and dims - # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] => - # [[(3, 4), (3, 4), (2, 4), (2, 2)], - # [(2, 4), (2, 4), (2, 4), (2, 2)]] - ceil_ratio = quotient + 1 - num_full_slots, left_over = np.divmod(original_size, ceil_ratio) - num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] - if len(num_or_size_splits) < dim: - num_or_size_splits += [0] * (dim - len(num_or_size_splits)) - new_output = [] - for x in output: - new_output.append( - array_ops.split( - x, num_or_size_splits=num_or_size_splits, axis=axis)) - output = new_output - else: - output = [array_ops.split(x, dim, axis=axis) for x in output] - output = nest.flatten(output) - return output - - def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): - """Tags appropriate XLA sharding attribute to the dequeued tensor. - - Args: - tensor: The dequeued tensor on TPU. - dims: A list of integer describes how the tensor is partitioned. - - Returns: - The same tensor with the xla_sharding attribute. - """ - if dims is None: - return xla_sharding.replicate(tensor) - elif np.prod(dims) == 1: - return xla_sharding.assign_device(tensor, 0) - else: - tile_assignment = np.arange(np.prod(dims)).reshape(dims) - return xla_sharding.tile( - tensor=tensor, - tile_assignment=tile_assignment) - - def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): - """Tags appropriate XLA sharding attribute to the dequeued tensors. - - Args: - dequeues: A list of dequeued tensors on TPU. - dims: A list of integer describes how the tensor is partitioned. - - Returns: - The same dequeues with appropriate xla_sharding attribute. 
- """ - nest.assert_shallow_structure(dequeues, dims) - return nest.map_structure_up_to( - dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues, - dims) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_feed import * +# used by tests +from tensorflow.python.tpu.tpu_feed import _PartitionedInfeedQueue +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py index 84d5967ea547f0c036f7c9aa936ac0c99c141304..f2755c6979c2e49dbc19b6800462949601811496 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_function.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py @@ -1,57 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= - -"""Helper library for functions used during TPU compilation.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib - - -class TpuContext(object): - """A context object holding state about the TPU computation being built.""" - - def __init__(self): - """Creates a new TpuContext.""" - self._number_of_shards = None - - @property - def number_of_shards(self): - return self._number_of_shards - - def set_number_of_shards(self, number_of_shards): - self._number_of_shards = number_of_shards - - -# The Tpu context holds the number of shards when a sharded computation is -# being built, or None if no computation is being built. -_current_tpu_context = TpuContext() - - -@contextlib.contextmanager -def tpu_shard_context(number_of_shards): - if _current_tpu_context.number_of_shards is not None: - raise NotImplementedError("tpu_shard_context cannot be nested.") - try: - _current_tpu_context.set_number_of_shards(number_of_shards) - yield - finally: - _current_tpu_context.set_number_of_shards(None) - - -def get_tpu_context(): - return _current_tpu_context +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_function import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index 1e11de6421e360faf0b9ad573a84f9aecdf9c98f..ca58e78d7b342c7ca70400652d99092ccbecbbde 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -1,203 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Optimizer that implements cross-shard gradient reduction for TPU.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import ops -from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import optimizer - - -class CrossShardOptimizer(optimizer.Optimizer): - """An optimizer that averages gradients across TPU shards.""" - - def __init__(self, - opt, - reduction=losses.Reduction.MEAN, - name="CrossShardOptimizer", - group_assignment=None): - """Construct a new cross-shard optimizer. - - Args: - opt: An existing `Optimizer` to encapsulate. - reduction: The reduction to apply to the shard losses. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "CrossShardOptimizer". - group_assignment: Optional 2d int32 lists with shape - [num_groups, num_replicas_per_group] which describles how to apply - optimizer to subgroups. 
- - Raises: - ValueError: If reduction is not a valid cross-shard reduction. - """ - if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN): - raise ValueError("Unsupported reduction: %s." % reduction) - - super(CrossShardOptimizer, self).__init__(False, name) - self._opt = opt - self._reduction = reduction - self._group_assignment = group_assignment - - def _verify_and_get_subgroup_size(self, group_assignment, num_shards): - """Verify group_assignment and get the subgroup size". - - Args: - group_assignment: list of group ids for applying the optimizer - to subgroups. - num_shards: The number of TPU shards. - - Returns: - The size of one subgroup in group_assignment. - - Raises: - ValueError: If group_assignment is invalid. - """ - if not group_assignment: - return None - if not (isinstance(group_assignment, list) and - all(isinstance(i, list) for i in group_assignment)): - raise ValueError("group_assignment must be a list of list. Got {}".format( - group_assignment)) - - replica_ids = set() - for g in group_assignment: - for i in g: - replica_ids.add(i) - - if set(range(num_shards)) != replica_ids: - raise ValueError("group_assignment must be a permutation of range({0})." - " Got group_assignment={1}".format( - num_shards, group_assignment)) - - subgroup_size_list = [len(group) for group in group_assignment] - if all(subgroup_size_list[0] == size for size in subgroup_size_list): - return subgroup_size_list[0] - else: - raise ValueError("The size of each subgroup in group_assignment must " - "be equal. Got group_assignment={}".format( - self._group_assignment)) - - def compute_gradients(self, loss, var_list=None, **kwargs): - """Compute gradients of "loss" for the variables in "var_list". - - This simply wraps the compute_gradients() from the real optimizer. The - gradients will be aggregated in the apply_gradients() so that user can - modify the gradients like clipping with per replica global norm if needed. 
- The global norm with aggregated gradients can be bad as one replica's huge - gradients can hurt the gradients from other replicas. - - Args: - loss: A Tensor containing the value to minimize. - var_list: Optional list or tuple of `tf.Variable` to update to minimize - `loss`. Defaults to the list of variables collected in the graph - under the key `GraphKey.TRAINABLE_VARIABLES`. - **kwargs: Keyword arguments for compute_gradients(). - - Returns: - A list of (gradient, variable) pairs. - - Raises: - ValueError: If not within a tpu_shard_context or group_assignment is - invalid. - """ - num_shards = tpu_function.get_tpu_context().number_of_shards - if num_shards is None: - logging.warning( - "CrossShardOptimizer should be used within a tpu_shard_context, but " - "got unset number_of_shards. Assuming 1.") - num_shards = 1 - - subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment, - num_shards) - - if num_shards > 1 and self._reduction == losses.Reduction.MEAN: - if self._group_assignment: - scale = 1.0 / subgroup_size - else: - scale = 1.0 / num_shards - loss *= scale - - return self._opt.compute_gradients(loss, var_list=var_list, **kwargs) - - def apply_gradients(self, grads_and_vars, global_step=None, name=None): - """Apply gradients to variables. - - Calls tpu_ops.cross_replica_sum() to sum gradient contributions across - replicas, and then applies the real optimizer. - - Args: - grads_and_vars: List of (gradient, variable) pairs as returned by - compute_gradients(). - global_step: Optional Variable to increment by one after the - variables have been updated. - name: Optional name for the returned operation. Default to the - name passed to the Optimizer constructor. - - Returns: - An `Operation` that applies the gradients. If `global_step` was not None, - that operation also increments `global_step`. - - Raises: - ValueError: If the grads_and_vars is malformed. 
- """ - summed_grads_and_vars = [] - for (grad, var) in grads_and_vars: - if grad is None: - summed_grads_and_vars.append((grad, var)) - else: - with ops.colocate_with(grad): - summed_grads_and_vars.append((tpu_ops.cross_replica_sum( - grad, self._group_assignment), var)) - return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) - - def get_slot(self, *args, **kwargs): - """Return a slot named "name" created for "var" by the Optimizer. - - This simply wraps the get_slot() from the actual optimizer. - - Args: - *args: Arguments for get_slot(). - **kwargs: Keyword arguments for get_slot(). - - Returns: - The `Variable` for the slot if it was created, `None` otherwise. - """ - return self._opt.get_slot(*args, **kwargs) - - def get_slot_names(self, *args, **kwargs): - """Return a list of the names of slots created by the `Optimizer`. - - This simply wraps the get_slot_names() from the actual optimizer. - - Args: - *args: Arguments for get_slot(). - **kwargs: Keyword arguments for get_slot(). - - Returns: - A list of strings. - """ - return self._opt.get_slot_names(*args, **kwargs) - - def variables(self): - """Forwarding the variables from the underlying optimizer.""" - return self._opt.variables() +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_optimizer import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py index f5af03f33ca8f13af517007672e9ce0e12be6205..93c52335a582e5fa83092f78212ca268079b7c12 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py @@ -1,253 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Helper library for sharding during TPU compilation.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.python.framework import tensor_shape - -_DEFAULT_NUMBER_OF_SHARDS = 1 -_DEFAULT_SHARD_DIMENSION = 0 - - -# TODO(b/36777903) change other parts of tpu.py to use this class. -class ShardingPolicy(object): - """An object use to hold the sharding policy for a Tensor. - """ - - def __init__(self): - self._number_of_shards = None - self._shard_dimension = None - self._frozen = False - - def __str__(self): - if self.number_of_shards is None or self.shard_dimension is None: - return "ShardingPolicy(unset)" - else: - return ("ShardingPolicy(%d shards dimension %d)" % - (self.number_of_shards, self.shard_dimension)) - - def _fill_default_values(self): - if self._number_of_shards is None: - self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS - if self._shard_dimension is None: - self._shard_dimension = tensor_shape.as_dimension( - _DEFAULT_SHARD_DIMENSION) - - def freeze(self): - """Prevents further modification to the sharding policy. - - Any values that have not been set when freeze is called are set to - defaults. 
If the ShardingPolicy is already frozen, this is a NoOp. - """ - if not self._frozen: - self._fill_default_values() - self._frozen = True - - @property - def number_of_shards(self): - """Returns the number of shards in the policy or None if unspecified.""" - return self._number_of_shards - - def set_number_of_shards(self, number_of_shards): - """Sets the number of shards for the current policy. - - If the policy has been frozen then number_of_shards must match the - existing setting. - - Args: - number_of_shards: The number of shards to use in the policy. - - Raises: - ValueError: If the policy has been frozen and number_of_shards - differs from the frozen value; or number_of_shards <= 0. - """ - if self._frozen: - if self._number_of_shards != number_of_shards: - raise ValueError( - "Can't set sharding policy to use %d shards since it has been " - "frozen to use %d." % (number_of_shards, self._number_of_shards)) - else: - if number_of_shards > 0: - self._number_of_shards = number_of_shards - else: - raise ValueError( - "Can't set sharding policy to use %s shards; value must be >0", - str(number_of_shards)) - - @property - def shard_dimension(self): - """Returns the shard dimension of the policy or None if unspecified.""" - return self._shard_dimension - - def set_shard_dimension(self, shard_dimension): - """Sets the shard dimension for the current policy. - - If the policy has been frozen then shard_dimension must match the - existing setting. - - Args: - shard_dimension: The shard dimension to use in the policy. - - Raises: - ValueError: If the policy has been frozen and shard_dimension - differs from the frozen value, or shard_dimension can't be - interpreted as a Dimension. - """ - if self._frozen: - if self._shard_dimension != shard_dimension: - raise ValueError( - "Can't set shard dimension to %d since it has been frozen to " - "use %d." 
% (shard_dimension, self._shard_dimension)) - else: - self._shard_dimension = tensor_shape.as_dimension(shard_dimension) - - def merge(self, other): - """Merges the policy of another policy into the current policy. - - Args: - other: The policy to merge into this one. - - Raises: - ValueError: If this policy has been frozen and the merge conflicts with - the frozen policy. - """ - if other.number_of_shards is not None: - self.set_number_of_shards(other.number_of_shards) - if other.shard_dimension is not None: - self.set_shard_dimension(other.shard_dimension) - - def get_sharded_shape(self, shape, shard_index=None): - """Returns the shape of a shard of a full Tensor. - - When given the shape of a 'full-size' Tensor, returns the shape of - the sub-Tensor after it has been sharded. Freezes the policy if it - has not yet been frozen. - - Args: - shape: The shape of the full-size Tensor to be sharded. - shard_index: The index of the shard whose shape should be returned. - shard_index can be None for sharding policies that use the same - shape for every shard. - freeze_config: - - Returns: - The shape of the sharded version of the Tensor. - - Raises: - ValueError: If shard_index is None when shards are of different - shapes; or shard_index is not None and - !(0<=shard_index= self.number_of_shards: - raise ValueError("shard_index %d, but must be in [0,%d)." % - (shard_index, self._number_of_shards)) - shape = tensor_shape.as_shape(shape) - if self._number_of_shards == 1: - # Don't do anything when there's only one shard. - return shape - ndims = shape.ndims - if ndims is None: - raise ValueError("shape must be a specified shape not Unknown") - if ndims <= self._shard_dimension: - raise ValueError("shape %s does not contain shard_dimension %d" % - (shape.as_list(), self._shard_dimension)) - dims = shape.as_list() - if dims[self._shard_dimension] is None: - raise ValueError("shape %s must have a fixed size for dimension %d " - "that is known at graph construction time." 
% - (shape.as_list(), self._shard_dimension)) - if (dims[self._shard_dimension] % self._number_of_shards) != 0: - raise ValueError("shape %s cannot be sharded %d ways along dimension %d" % - (shape.as_list(), self._number_of_shards, - self._shard_dimension)) - dims[self._shard_dimension] /= self._number_of_shards - return tensor_shape.as_shape(dims) - - def _unshard_shape(self, shape): - """Return the unsharded shape that would generate a given sharded shape. - - Args: - shape: the sharded shape to unshard - - Returns: - The unsharded shape. - - Raises: - ValueError: if shape is unknown or does not contain - self.shard_dimension - TypeError: if shape is not convertible to a TensorShape - """ - shape = tensor_shape.as_shape(shape) - if self._number_of_shards == 1: - # Don't do anything when there's only one shard. - return shape - ndims = shape.ndims - if ndims is None: - raise ValueError("shape must be a specified shape not Unknown") - if ndims <= self._shard_dimension: - raise ValueError("shape %s does not contain shard_dimension %d" % - (shape.as_list(), self._shard_dimension)) - dims = shape.as_list() - dims[self._shard_dimension] *= self._number_of_shards - return tensor_shape.as_shape(dims) - - def get_unsharded_shape(self, shapes): - """Returns the shape of an unsharded Tensor given a list of shards. - - When given a list of shapes of shards, returns the shape of the - unsharded Tensor that would generate the shards. Sets defaults for the - policy if number_of_shards or shard_dimension is None. - - Args: - shapes: The shapes of the Tensor shards to be combined. - - Returns: - The shape of the unsharded version of the Tensor. - - Raises: - ValueError: if shapes is not a list of length - self.number_of_shards; or any element of shapes is not a valid - shape consistent with the sharding policy; or the list of - shapes is not a valid sharding of a full shape. 
- TypeError: if an element of shapes is not convertible to a - TensorShape - """ - self._fill_default_values() - if len(shapes) != self.number_of_shards: - raise ValueError( - "shapes is %s but must be a list of length number_of_shards=%d" % ( - str(shapes), self.number_of_shards)) - unsharded_shapes = [self._unshard_shape(s) for s in shapes] - for i in xrange(self.number_of_shards - 1): - if not unsharded_shapes[i].is_compatible_with( - unsharded_shapes[self.number_of_shards - 1]): - raise ValueError( - "sharded shapes %s are not consistent shards of a full shape " - "sharded %d ways along dimension %d" % ( - str(shapes), self.number_of_shards, self.shard_dimension)) - return unsharded_shapes[0] +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_sharding import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index d66ecfcf4a56b8da1c2d2f518bebe4baa76b315e..258d34ddaf5250e49c5a354caf018e4b64abae62 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -1,156 +1,25 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# =================================================================== -"""TPU system metadata and associated tooling.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import re - -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging - -_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000 # 1 min -_RETRY_TIMES = 120 -_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins - -_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$') - -# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration, -# including num_cores and num_hosts. -_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ - 'num_cores', - 'num_hosts', - 'num_of_cores_per_host', - 'topology', - 'devices', -]) - - -def _query_tpu_system_metadata(master_address, cluster_def=None, - query_topology=False): - """Automatically detects the TPU system metadata in the system.""" - tpu_core_count = 0 - devices = [] - device_dict = collections.defaultdict(list) - - # TODO(b/120564445): Replace with standard library for retries. 
- retry_count = 1 - while True: - logging.info('Querying Tensorflow master (%s) for TPU system metadata.', - master_address) - try: - with ops.Graph().as_default(): - with session_lib.Session( - master_address, - config=get_session_config_with_timeout( - _PINGING_MASTER_TIMEOUT_IN_MS, - cluster_def)) as sess: - devices = sess.list_devices() - for device in devices: - match = _TPU_DEVICE_REG.match(device.name) - if match: - host_id = match.group(1) - core_id = match.group(2) - device_dict[host_id].append(core_id) - tpu_core_count += 1 - break - except errors.DeadlineExceededError: - msg = ('Failed to connect to the Tensorflow master. The TPU worker may ' - 'not be ready (still scheduling) or the Tensorflow master address ' - 'is incorrect: got (%s).' % - (master_address)) - - # TODO(xiejw): For local or grpc master we might not need retry logic - # here. - if retry_count <= _RETRY_TIMES: - logging.warning('%s', msg) - logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES) - retry_count += 1 - else: - raise ValueError(msg) - - num_of_cores_per_host = 0 - if tpu_core_count: - num_cores_per_host_set = set( - [len(core_ids) for core_ids in device_dict.values()]) - if len(num_cores_per_host_set) != 1: - raise RuntimeError( - 'TPU cores on each host is not same. This should not happen!. ' - 'devices: {}'.format(devices)) - num_of_cores_per_host = num_cores_per_host_set.pop() - - topology = None - if query_topology: - if not tpu_core_count: - raise RuntimeError( - 'Cannot find any TPU cores in the system (master address {}). ' - 'This usually means the master address is incorrect or the ' - 'TPU worker has some problems. 
Available devices: {}'.format( - master_address, devices)) - - topology = _obtain_topology(master_address, cluster_def) - - metadata = _TPUSystemMetadata( - num_cores=tpu_core_count, - num_hosts=len(device_dict), - num_of_cores_per_host=num_of_cores_per_host, - topology=topology, - devices=devices) - - if tpu_core_count: - logging.info('Found TPU system:') - logging.info('*** Num TPU Cores: %d', metadata.num_cores) - logging.info('*** Num TPU Workers: %d', metadata.num_hosts) - logging.info('*** Num TPU Cores Per Worker: %d', - metadata.num_of_cores_per_host) - for device in metadata.devices: - logging.info('*** Available Device: %s', device) - else: - logging.info('Failed to find TPU: %s', metadata) - return metadata - - -def _obtain_topology(master_address, cluster_def): - """Obtains TPU fabric topology.""" - try: - logging.info('Initializing TPU system (master: %s) to fetch topology ' - 'for model parallelism. This might take a while.', - master_address) - with ops.Graph().as_default(): - session_config = get_session_config_with_timeout( - _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def) - with session_lib.Session( - master_address, config=session_config) as sess: - topology = sess.run(tpu.initialize_system()) - return topology - except errors.DeadlineExceededError: - raise ValueError( - 'Fail to initialize TPU system with master (%s). ' - 'Please double check the TPU system is functional.' 
% ( - master_address)) - - -def get_session_config_with_timeout(timeout_in_secs, cluster_def): - """Returns a session given a timeout and a cluster configuration.""" - config = config_pb2.ConfigProto( - operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def) - return config +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_system_metadata import * +# used by tests +from tensorflow.python.tpu.tpu_system_metadata import _query_tpu_system_metadata +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/training_loop.py b/tensorflow/contrib/tpu/python/tpu/training_loop.py index 0187b4bec6ecc55943bf48b9268a74e18ea5b488..673359b232d6857d468723873c449cb3e48168c7 100644 --- a/tensorflow/contrib/tpu/python/tpu/training_loop.py +++ b/tensorflow/contrib/tpu/python/tpu/training_loop.py @@ -1,214 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= - -"""Library for constructing a training loop, suitable for TPUs.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.compiler import xla -from tensorflow.contrib.tpu.python.tpu import tpu_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops - - -def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): - """Builds a training loop for TPUs. - - The set of loop-carried tensors corresponds to `inputs`. Both - `condition` and `body` take the current value of the loop-carried - tensors. 'body' additionally takes a tuple of infeed from - infeed_queue if infeed_queue is not None. `condition` must return a - single boolean value that determines whether iteration - continues. `body` must return an updated list of values for the - loop-carried tensors. - - Args: - condition: a Python function that builds the loop condition. - body: a Python function that builds the loop body. - inputs: a list of initial values passed into the training loop, or - None (equivalent to an empty list). - infeed_queue: if not None, the infeed queue from which to append a tuple - of arguments as inputs to condition. - name: (Deprecated) Does nothing. - - Returns: - The final values of the loop-carried tensors. - - Raises: - TypeError: if body or condition has the wrong signature. - """ - del name - # Converts inputs to Tensors. 
- inputs = [] if inputs is None else [ops.convert_to_tensor(x) for - x in inputs] - input_types = [x.dtype for x in inputs] - input_arity = len(inputs) - - body_arg_error = xla.check_function_argument_count( - body, input_arity, infeed_queue) - if body_arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied loop body function cannot be called with the specified " - "inputs. You specified %d inputs: %s, but the loop body needs %s" % ( - input_arity, str([i.name for i in inputs]), body_arg_error)) - else: - raise TypeError( - "Supplied loop body function cannot be called with the specified " - "inputs. You specified %d inputs: %s and %d additional inputs from " - "infeed, but the computation needs %s" % (input_arity, str( - [i.name for i in inputs]), infeed_queue.number_of_tuple_elements, - body_arg_error)) - condition_arg_error = xla.check_function_argument_count( - condition, input_arity, None) - if condition_arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied loop condition function cannot be called with the " - "specified inputs. You specified %d inputs: %s, but the loop " - "condition needs %s" % (input_arity, str([i.name for i in inputs]), - condition_arg_error)) - else: - raise TypeError( - "Supplied loop condition function cannot be called with the " - "specified inputs. You specified %d inputs: %s, but the loop " - "condition needs %s. Note that infeed is not passed to the loop " - "condition." % (input_arity, str([i.name for i in inputs]), - condition_arg_error)) - - def condition_wrapper(*inputs): - # Discards the dummy output added for arity-0 loops. - if input_arity == 0: - inputs = [] - return condition(*inputs) - - def body_wrapper(*inputs): - """Wrapper around `body` that handles infeed queues and control deps.""" - inputs = list(inputs) - - # Discards the dummy output added for arity-0 loops. - if input_arity == 0: - inputs = [] - - # Runs `body` with the dequeue_ops appended. 
- if infeed_queue: - number_of_shards = tpu_function.get_tpu_context().number_of_shards - if number_of_shards is None: - raise ValueError("Can't build training loop with infeed when there is " - "no tpu_shard_context. Are you building a loop or " - "graph directly rather than from inside tpu.rewrite, " - "tpu.batch_parallel, tpu.shard, or tpu.replicate?") - infeed_queue.set_number_of_shards(number_of_shards) - dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()] - else: - dequeue_ops = [] - outputs = body(*(inputs + dequeue_ops)) - - # If the computation only returned one value, make it a tuple. - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs - if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU training loop body must return zero or more Tensor values " - "followed by zero or more Operations.") - - output_types = [op.dtype for op in output_tensors] - if input_types != output_types: - raise TypeError( - "Mismatch between input types and output types for training loop " - "body: {} vs {}".format(input_types, output_types)) - - # Add the dequeue operations to output_operations to ensure they are run - # by the loop, even if the programmer's loop body does not use them. - output_operations += dequeue_ops - - # Add a dummy output, if needed. - if not output_tensors: - output_tensors = array_ops.constant(0) - - if output_operations: - # TODO(phawkins): in principle this is too restrictive since it serializes - # the training loop steps. In practice it does not matter since this loop - # will be compiled by XLA. 
- return control_flow_ops.tuple(output_tensors, - control_inputs=output_operations) - else: - return output_tensors - - # If the body has arity 0, add a dummy loop-carried value to which we can add - # control dependencies from any side-effecting operations. - if input_arity == 0: - inputs = [array_ops.constant(0)] - return control_flow_ops.while_loop( - condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1) - - -def repeat(n, body, inputs=None, infeed_queue=None, name=None): - """Builds a training loop that executes a fixed number of iterations. - - The set of loop-carried tensors correspond to `inputs`. - `body` must be a function that takes and returns the values of the - loop-carried tensors. - - Args: - n: the number of loop iterations - body: a Python function that builds the loop body. - inputs: a list of initial values passed into the training loop or - None (equivalent to an empty list). - infeed_queue: if not None, the infeed queue from which to append a tuple - of arguments as inputs to condition. - name: (Deprecated) Does nothing. - Returns: - The final values of the loop-carried tensors. - Raises: - ValueError: if there is a type error. - """ - def _convert_to_list(xs): - if not isinstance(xs, (list, tuple)): - return [xs] - else: - return list(xs) - - def cond(i, *args): - del args - return i < n - - def body_wrapper(i, *args): - return [i + 1] + _convert_to_list(body(*args)) - - inputs = [0] if inputs is None else [0] + _convert_to_list(inputs) - outputs = while_loop( - cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name) - outputs = _convert_to_list(outputs) - if len(outputs) == 1: - # Returns the Op rather than an empty list. 
- return outputs[0].op - else: - return outputs[1:] +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.training_loop import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/util.py b/tensorflow/contrib/tpu/python/tpu/util.py index dfb8ce1d1821da05c853bb0d10b1db3a857ccb1b..8d9b70d46eb42c9a525eeafc51d07f0ad4241d52 100644 --- a/tensorflow/contrib/tpu/python/tpu/util.py +++ b/tensorflow/contrib/tpu/python/tpu/util.py @@ -1,51 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# =================================================================== - -"""Utilities for the functionalities.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time -import six - -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import training - -def check_positive_integer(value, name): - """Checks whether `value` is a positive integer.""" - if not isinstance(value, six.integer_types): - raise TypeError('{} must be int, got {}'.format(name, type(value))) - - if value <= 0: - raise ValueError('{} must be positive, got {}'.format(name, value)) - - -# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we -# release a tensorflow_estimator with MultiHostDatasetInitializerHook in -# python/estimator/util.py. -class MultiHostDatasetInitializerHook(training.SessionRunHook): - """Creates a SessionRunHook that initializes all passed iterators.""" - - def __init__(self, dataset_initializers): - self._initializers = dataset_initializers - - def after_create_session(self, session, coord): - del coord - start = time.time() - session.run(self._initializers) - logging.info('Initialized dataset iterators in %d seconds', - time.time() - start) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.util import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index f6427ae05a20f253edf030eff0f860361616042b..5bc4c3b88efd641b6f17a54753a29b0603c2b98c 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -264,9 +264,9 @@ py_test( py_test( name = "training_test", - size = "large", + size = "medium", srcs = ["python/training/training_test.py"], - shard_count = 3, + shard_count = 
8, srcs_version = "PY2AND3", tags = ["notsan"], deps = [ diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 3beb7bfe3048a8f0294f7e9149b5a07b5fcc7d17..27f0d9b2e38c433d4fb4573285ecb8c9946112e8 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -187,7 +187,7 @@ def _cast_to_type_if_compatible(name, param_type, value): return param_type(value) -def parse_values(values, type_map): +def parse_values(values, type_map, ignore_unknown=False): """Parses hyperparameter values from a string into a python map. `values` is a string containing comma-separated `name=value` pairs. @@ -233,6 +233,9 @@ def parse_values(values, type_map): type T if either V has type T, or V is a list of elements of type T. Hence, for a multidimensional parameter 'x' taking float values, 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + ignore_unknown: Bool. Whether values that are missing a type in type_map + should be ignored. If set to True, a ValueError will not be raised for + unknown hyperparameter type. Returns: A python map mapping each name to either: @@ -260,6 +263,8 @@ def parse_values(values, type_map): m_dict = m.groupdict() name = m_dict['name'] if name not in type_map: + if ignore_unknown: + continue raise ValueError('Unknown hyperparameter type for %s' % name) type_ = type_map[name] @@ -494,6 +499,7 @@ class HParams(object): value: New value of the hyperparameter. Raises: + KeyError: If the hyperparameter doesn't exist. ValueError: If there is a type mismatch. """ param_type, is_list = self._hparam_types[name] @@ -512,6 +518,8 @@ class HParams(object): def del_hparam(self, name): """Removes the hyperparameter with key 'name'. + Does nothing if it isn't present. + Args: name: Name of the hyperparameter. 
""" @@ -520,19 +528,20 @@ class HParams(object): del self._hparam_types[name] def parse(self, values): - """Override hyperparameter values, parsing new values from a string. + """Override existing hyperparameter values, parsing new values from a string. See parse_values for more detail on the allowed format for values. Args: - values: String. Comma separated list of `name=value` pairs where - 'value' must follow the syntax described above. + values: String. Comma separated list of `name=value` pairs where 'value' + must follow the syntax described above. Returns: The `HParams` instance. Raises: - ValueError: If `values` cannot be parsed. + ValueError: If `values` cannot be parsed or a hyperparameter in `values` + doesn't exist. """ type_map = dict() for name, t in self._hparam_types.items(): @@ -543,7 +552,7 @@ class HParams(object): return self.override_from_dict(values_map) def override_from_dict(self, values_dict): - """Override hyperparameter values, parsing new values from a dictionary. + """Override existing hyperparameter values, parsing new values from a dictionary. Args: values_dict: Dictionary of name:value pairs. @@ -552,6 +561,7 @@ class HParams(object): The `HParams` instance. Raises: + KeyError: If a hyperparameter in `values_dict` doesn't exist. ValueError: If `values_dict` cannot be parsed. """ for name, value in values_dict.items(): @@ -591,7 +601,7 @@ class HParams(object): sort_keys=sort_keys) def parse_json(self, values_json): - """Override hyperparameter values, parsing new values from a json object. + """Override existing hyperparameter values, parsing new values from a json object. Args: values_json: String containing a json object of name:value pairs. @@ -600,6 +610,7 @@ class HParams(object): The `HParams` instance. Raises: + KeyError: If a hyperparameter in `values_json` doesn't exist. ValueError: If `values_json` cannot be parsed. 
""" values_map = json.loads(values_json) diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 660c97f25e8458c345c8914bcaf98f37d047e50e..a990e04711ce68bd928a508484f0d6f657dd2f8c 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -216,6 +216,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment1_IgnoreUnknown(self): + """Assignment to an index position.""" + parse_dict = hparam.parse_values( + 'arr[1]=10,b=5', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment2(self): """Assignment to multiple index positions.""" parse_dict = hparam.parse_values('arr[0]=10,arr[5]=20', {'arr': int}) @@ -223,6 +231,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment2_IgnoreUnknown(self): + """Assignment to multiple index positions.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,arr[5]=20,foo=bar', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment3(self): """Assignment to index positions in multiple names.""" parse_dict = hparam.parse_values('arr[0]=10,arr[1]=20,L[5]=100,L[10]=200', @@ -234,6 +250,17 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['L'], dict)) self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment3_IgnoreUnknown(self): + 
"""Assignment to index positions in multiple names.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,C=5,arr[1]=20,B[0]=kkk,L[5]=100,L[10]=200', + {'arr': int, 'L': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 2) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 1: 20}) + self.assertTrue(isinstance(parse_dict['L'], dict)) + self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment4(self): """Assignment of index positions and scalars.""" parse_dict = hparam.parse_values('x=10,arr[1]=20,y=30', @@ -246,6 +273,17 @@ class HParamsTest(test.TestCase): self.assertEqual(parse_dict['x'], 10) self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment4_IgnoreUnknown(self): + """Assignment of index positions and scalars.""" + parse_dict = hparam.parse_values( + 'x=10,foo[0]=bar,arr[1]=20,zzz=78,y=30', + {'x': int, 'y': int, 'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 3) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 20}) + self.assertEqual(parse_dict['x'], 10) + self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment5(self): """Different variable types.""" parse_dict = hparam.parse_values('a[0]=5,b[1]=true,c[2]=abc,d[3]=3.14', { @@ -264,24 +302,55 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['d'], dict)) self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithIndexAssigment5_IgnoreUnknown(self): + """Different variable types.""" + parse_dict = hparam.parse_values( + 'a[0]=5,cc=4,b[1]=true,c[2]=abc,mm=2,d[3]=3.14', + {'a': int, 'b': bool, 'c': str, 'd': float}, + ignore_unknown=True) + self.assertEqual(set(parse_dict.keys()), {'a', 'b', 'c', 'd'}) + self.assertTrue(isinstance(parse_dict['a'], dict)) + self.assertDictEqual(parse_dict['a'], {0: 5}) + self.assertTrue(isinstance(parse_dict['b'], dict)) 
+ self.assertDictEqual(parse_dict['b'], {1: True}) + self.assertTrue(isinstance(parse_dict['c'], dict)) + self.assertDictEqual(parse_dict['c'], {2: 'abc'}) + self.assertTrue(isinstance(parse_dict['d'], dict)) + self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithBadIndexAssigment1(self): """Reject assignment of list to variable type.""" with self.assertRaisesRegexp(ValueError, r'Assignment of a list to a list index.'): hparam.parse_values('arr[1]=[1,2,3]', {'arr': int}) + def testParseValuesWithBadIndexAssigment1_IgnoreUnknown(self): + """Reject assignment of list to variable type.""" + with self.assertRaisesRegexp(ValueError, + r'Assignment of a list to a list index.'): + hparam.parse_values( + 'arr[1]=[1,2,3],c=8', {'arr': int}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment2(self): """Reject if type missing.""" with self.assertRaisesRegexp(ValueError, r'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=5', {}) + def testParseValuesWithBadIndexAssigment2_IgnoreUnknown(self): + """Ignore missing type.""" + hparam.parse_values('arr[1]=5', {}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment3(self): """Reject type of the form name[index].""" with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=1', {'arr[1]': int}) + def testParseValuesWithBadIndexAssigment3_IgnoreUnknown(self): + """Ignore type of the form name[index].""" + hparam.parse_values('arr[1]=1', {'arr[1]': int}, ignore_unknown=True) + def testWithReusedVariables(self): with self.assertRaisesRegexp(ValueError, 'Multiple assignments to variable \'x\''): diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index c272a2ac144068cfb7355c2647eebf5bd0ce9d50..4ceb6e9350f5167efc8f7266d4e748cc6fa4ffd6 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ 
b/tensorflow/contrib/training/python/training/training.py @@ -244,7 +244,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -354,11 +353,11 @@ def multiply_gradients(grads_and_vars, gradient_multipliers): raise ValueError('Requested multiple of `None` gradient.') if isinstance(grad, ops.IndexedSlices): - tmp = grad.values * constant_op.constant( + tmp = grad.values * ops.convert_to_tensor( gradient_multipliers[key], dtype=grad.dtype) grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape) else: - grad *= constant_op.constant( + grad *= ops.convert_to_tensor( gradient_multipliers[key], dtype=grad.dtype) multiplied_grads_and_vars.append((grad, var)) return multiplied_grads_and_vars @@ -419,7 +418,7 @@ def create_train_op(total_loss, update_ops = set(update_ops) if not global_update_ops.issubset(update_ops): logging.warning('update_ops in create_train_op does not contain all the ' - ' update_ops in GraphKeys.UPDATE_OPS') + 'update_ops in GraphKeys.UPDATE_OPS') # Make sure update_ops are computed before total_loss. 
if update_ops: @@ -433,7 +432,7 @@ def create_train_op(total_loss, else: # Make sure that variables_to_train are in tf.trainable_variables() for v in variables_to_train: - assert v in tf_variables.trainable_variables() + assert v.trainable or v in tf_variables.trainable_variables() assert variables_to_train diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index d9ccda8e89a4c9a1b3f3d24915b9ad3fb4d9be5f..ada08f95ae46ea06b3896ca3b1603277d62bf6fc 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -16,10 +16,15 @@ cc_library( srcs = ["convert_graphdef_memmapped_format_lib.cc"], hdrs = ["convert_graphdef_memmapped_format_lib.h"], deps = [ + "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:tensorflow", "//tensorflow/core/kernels:immutable_constant_op", ], diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 2784bf124ceaacd8e01f0653287fa7f006d0d608..2f2375427862ad1e99a0e6bfc506382d200e9b1d 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -277,9 +277,18 @@ void RdmaMgr::InitAllocators() { ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); #if GOOGLE_CUDA + GPUProcessState::singleton()->AddCUDAHostAllocVisitor(0, alloc_visitor); + GPUProcessState::singleton()->AddCUDAHostFreeVisitor(0, free_visitor); + if (IsGDRAvailable()) { // Note we don't free allocated GPU memory so there is no free visitor - int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1; + + // TODO: This is to fix the 'invalid use of member in static member function + // bug'. + // Waiting for better implementation. 
+ // int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + // + 1; + int32_t bus_id = 0; SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id, size_t num_bytes) { @@ -288,9 +297,6 @@ void RdmaMgr::InitAllocators() { }; GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id, - alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index 5b72b1604aca2e0c593978c6104322372788eb3c..d07fd5ae6e9cc0dbf67c6b6a4e8db086b4c74aa1 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -33,6 +33,8 @@ RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) { return new RdmaRendezvousMgr(env); } +std::once_flag reg_mem_visitors_call; + } // namespace VerbsServer::VerbsServer(const ServerDef& server_def, Env* env) @@ -76,14 +78,13 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, return Status::OK(); } -namespace { -std::once_flag reg_mem_visitors_call; -} // namespace - Status VerbsServer::Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func) { std::call_once(reg_mem_visitors_call, []() { RdmaMgr::RegMemVisitors(); }); - Status s = GrpcServer::Init(service_func, rendezvous_mgr_func); + GrpcServerOptions opts; + opts.service_func = service_func; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + Status s = GrpcServer::Init(opts); { mutex_lock l(mu_); CHECK_EQ(verbs_state_, DISCONNECTED); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 66714235b535c14a8f13c40bb2a4df8d7494dc05..06c108b38fbf1d4b796c313ce700332803c73ef9 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -49,7 +49,7 @@ 
# filegroup ":android_proto_srcs" - Protos # filegroup ":android_srcs" - Core sources # cc_library ":android_tensorflow_lib" - Native library -# cc_library ":android_tensorflow_lib_selective_registration" - Native library +# cc_library ":android_tensorflow_lib_lite" - Native library, without ops, # supporting SELECTIVE_REGISTRATION feature. # portable_proto_library ":android_proto_lib" (Google-internal) # @@ -70,10 +70,14 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +# Export the BUILD file so automated tooling can check licenses +exports_files(["BUILD"]) + load( "//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_android", + "if_emscripten", "if_ios", "if_linux_x86_64", "if_mobile", @@ -84,10 +88,12 @@ load( "tf_copts", "tf_cuda_library", "tf_features_nomodules_if_android", + "tf_features_nomodules_if_emscripten", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", + "tf_opts_nortti_if_emscripten", "transitive_hdrs", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") @@ -113,7 +119,6 @@ load( "tf_additional_device_tracer_test_flags", "tf_additional_gdr_lib_defines", "tf_additional_human_readable_json_deps", - "tf_additional_logger_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", @@ -123,7 +128,6 @@ load( "tf_additional_libdevice_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_mpi_lib_defines", - "tf_additional_proto_compiler_hdrs", "tf_additional_proto_hdrs", "tf_additional_proto_srcs", "tf_additional_test_deps", @@ -142,6 +146,7 @@ load( "tf_protos_grappler", "tf_protos_grappler_impl", "tf_pyclif_proto_library", + "tf_grpc_service_all", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -179,7 +184,6 @@ COMMON_PROTO_SRCS = [ "framework/function.proto", "framework/graph.proto", "framework/graph_transfer_info.proto", - "framework/iterator.proto", "framework/kernel_def.proto", 
"framework/log_memory.proto", "framework/node_def.proto", @@ -200,10 +204,12 @@ COMMON_PROTO_SRCS = [ "protobuf/cluster.proto", "protobuf/debug.proto", "protobuf/device_properties.proto", + "protobuf/graph_debug_info.proto", "protobuf/queue_runner.proto", "protobuf/rewriter_config.proto", "protobuf/tensor_bundle.proto", "protobuf/saver.proto", + "protobuf/verifier_config.proto", "util/event.proto", "util/memmapped_file_system.proto", "util/saved_tensor_slice.proto", @@ -223,13 +229,15 @@ CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS # ones with individual proto_library targets. ADDITIONAL_CORE_PROTO_SRCS = [ "example/example_parser_configuration.proto", - "protobuf/checkpointable_object_graph.proto", + "protobuf/trackable_object_graph.proto", "protobuf/control_flow.proto", # TODO(ebrevdo): Re-enable once CriticalSection is in core. # "protobuf/critical_section.proto", "protobuf/meta_graph.proto", "protobuf/named_tensor.proto", "protobuf/saved_model.proto", + "protobuf/saved_object_graph.proto", + "protobuf/struct.proto", "protobuf/tensorflow_server.proto", "protobuf/transport_options.proto", "util/test_log.proto", @@ -412,9 +420,8 @@ cc_library( name = "platform_protobuf", srcs = tf_platform_hdrs([ "protobuf.h", - ]) + tf_platform_srcs([ - "protobuf.cc", ]) + [ + "platform/protobuf.cc", "platform/protobuf_util.cc", "lib/core/status.h", ], @@ -433,6 +440,17 @@ cc_library( ], ) +cc_library( + name = "grpc_services", + srcs = [], + hdrs = [ + "platform/grpc_services.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = tf_grpc_service_all(), +) + cc_library( name = "human_readable_json", srcs = tf_platform_srcs(["human_readable_json.cc"]), @@ -447,14 +465,11 @@ cc_library( cc_library( name = "logger", - srcs = tf_platform_srcs(["logger.cc"]), - hdrs = ["platform/logger.h"] + tf_platform_hdrs(["logger.h"]), + srcs = ["platform/logger.cc"], + hdrs = ["platform/logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], 
- deps = [ - ":lib", - ":lib_internal", - ] + tf_additional_logger_deps(), + deps = [":lib_proto_parsing"], ) filegroup( @@ -505,6 +520,7 @@ cc_library( ":platform_port", ":platform_protobuf", "//tensorflow/core/platform/default/build_config:env", + "//tensorflow/core/platform/default/build_config:port", ], ) @@ -660,7 +676,7 @@ cc_library( name = "lib_proto_compiler", hdrs = [ "platform/protobuf_compiler.h", - ] + tf_additional_proto_compiler_hdrs(), + ], copts = tf_copts(), deps = tf_lib_proto_compiler_deps() + [ ":lib_proto_parsing", @@ -1018,6 +1034,7 @@ cc_library( ":lib", ":lib_internal", ":protos_all_cc", + "//tensorflow/core/util/proto:proto_utils", ], ) @@ -1044,13 +1061,13 @@ cc_library( "platform/default/integral_types.h", "platform/default/logging.h", "platform/default/mutex.h", - "platform/default/protobuf.h", "platform/default/thread_annotations.h", "platform/dynamic_annotations.h", "platform/macros.h", "platform/mutex.h", "platform/platform.h", "platform/prefetch.h", + "platform/protobuf.h", "platform/thread_annotations.h", "platform/types.h", "platform/cpu_info.h", @@ -1075,6 +1092,7 @@ tf_gen_op_libs( "tensor_forest_ops", "candidate_sampling_ops", "checkpoint_ops", + "clustering_ops", "collective_ops", "control_flow_ops", "ctc_ops", @@ -1100,6 +1118,7 @@ tf_gen_op_libs( "parsing_ops", "random_grad", "random_ops", + "stateful_random_ops", "remote_fused_graph_ops", "rpc_ops", "scoped_allocator_ops", @@ -1134,6 +1153,13 @@ tf_gen_op_libs( deps = [":protos_all_cc"], ) +tf_gen_op_libs( + op_lib_names = [ + "mkl_array_ops", + ], + deps = [":protos_all_cc"], +) + tf_gen_op_libs( op_lib_names = [ "audio_ops", @@ -1154,6 +1180,29 @@ tf_gen_op_libs( deps = [":lib"], ) +tf_gen_op_libs( + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + deps = 
[ + ":lib", + ":lib_proto_parsing", + ":protos_all_cc", + "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", + "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", + "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", + ], +) + # And one for all user ops cc_library( name = "user_ops_op_lib", @@ -1229,6 +1278,7 @@ cc_library( ":tensor_forest_ops_op_lib", ":candidate_sampling_ops_op_lib", ":checkpoint_ops_op_lib", + ":clustering_ops_op_lib", ":collective_ops_op_lib", ":control_flow_ops_op_lib", ":ctc_ops_op_lib", @@ -1254,6 +1304,7 @@ cc_library( ":parsing_ops_op_lib", ":ragged_ops", ":random_ops_op_lib", + ":stateful_random_ops_op_lib", ":remote_fused_graph_ops_op_lib", ":resource_variable_ops_op_lib", ":rpc_ops_op_lib", @@ -1268,10 +1319,23 @@ cc_library( ":state_ops_op_lib", ":stateless_random_ops_op_lib", ":string_ops_op_lib", + ":tpu_configuration_ops_op_lib", + ":tpu_cross_replica_ops_op_lib", + ":tpu_embedding_ops_op_lib", + ":tpu_functional_ops_op_lib", + ":tpu_heartbeat_ops_op_lib", + ":tpu_host_compute_ops_op_lib", + ":tpu_infeed_ops_op_lib", + ":tpu_outfeed_ops_op_lib", + ":tpu_ordinal_selector_ops_op_lib", + ":tpu_replication_ops_op_lib", ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", - ] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(), + ] + if_mkl([ + ":mkl_array_ops_op_lib", + ":mkl_nn_ops_op_lib", + ]) + tf_additional_cloud_op_deps(), alwayslink = 1, ) @@ -1372,8 +1436,8 @@ cc_library( # This includes implementations of all kernels built into TensorFlow. 
cc_library( - name = "all_kernels_statically_linked", - visibility = ["//visibility:private"], + name = "all_kernels_impl", + visibility = ["//tensorflow/core:__subpackages__"], deps = [ "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", @@ -1383,12 +1447,12 @@ cc_library( "//tensorflow/core/kernels:tensor_forest_ops", "//tensorflow/core/kernels:candidate_sampler_ops", "//tensorflow/core/kernels:checkpoint_ops", + "//tensorflow/core/kernels:clustering_ops", "//tensorflow/core/kernels:collective_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:ctc_ops", "//tensorflow/core/kernels:cudnn_rnn_kernels", "//tensorflow/core/kernels:data_flow", - "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:decode_proto_op", "//tensorflow/core/kernels:encode_proto_op", "//tensorflow/core/kernels:fake_quant_ops", @@ -1399,18 +1463,20 @@ cc_library( "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", - "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:manip", "//tensorflow/core/kernels:math", "//tensorflow/core/kernels:multinomial_op", + "//tensorflow/core/kernels:mutex_ops", "//tensorflow/core/kernels:nn", "//tensorflow/core/kernels:parameterized_truncated_normal_op", "//tensorflow/core/kernels:parsing", "//tensorflow/core/kernels:partitioned_function_ops", + "//tensorflow/core/kernels:pooling_ops", "//tensorflow/core/kernels:ragged_ops", "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:stateful_random_ops", "//tensorflow/core/kernels:random_poisson_op", "//tensorflow/core/kernels:remote_fused_graph_ops", "//tensorflow/core/kernels:required", @@ -1441,6 +1507,7 @@ cc_library( "//tensorflow/core/kernels:mkl_identity_op", "//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_lrn_op", + "//tensorflow/core/kernels:mkl_requantize_ops", 
"//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", @@ -1462,8 +1529,13 @@ cc_library( visibility = ["//visibility:public"], deps = if_dynamic_kernels( [], - otherwise = [":all_kernels_statically_linked"], - ), + otherwise = [":all_kernels_impl"], + ) + [ + # TODO(gunan): Work on the API between these and rest of TF and make + # these also dynamically loading. + "//tensorflow/core/kernels:dataset_ops", # Depends on grappler + "//tensorflow/core/kernels:list_kernels", # Depends on variant_op_registry.h + ], ) tf_cuda_library( @@ -1524,6 +1596,7 @@ cc_library( ":framework_internal", ":lib", ":lib_internal", + ":ops", ":protos_all_cc", ":shape_inference_testutil", ":tensor_testutil", @@ -1611,6 +1684,9 @@ filegroup( "**/*main.cc", "debug/**/*", "framework/op_gen_*", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", "lib/jpeg/**/*", "lib/png/**/*", "lib/gif/**/*", @@ -1619,7 +1695,6 @@ filegroup( "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/**/logger.cc", - "platform/**/logger.h", "platform/default/test_benchmark.*", "platform/cuda.h", "platform/google/**/*", @@ -1654,6 +1729,9 @@ filegroup( "common_runtime/**/*.cc", "graph/**/*.h", "graph/**/*.cc", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", ], exclude = [ "**/*test.*", @@ -1743,11 +1821,35 @@ cc_library( cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) +cc_library( + name = "emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", + srcs = if_emscripten(["//tensorflow/core:mobile_srcs_no_runtime"]), + copts = ["-DSUPPORT_SELECTIVE_REGISTRATION"] + tf_opts_nortti_if_emscripten(), + defines = ["TENSORFLOW_LITE_PROTOS"], + linkopts = ["-lz"], + tags = [ + "manual", + 
"notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":emscripten_proto_lib_no_rtti_lite_runtime", + ":mobile_additional_lib_deps", + ":stats_calculator_portable", + "//third_party/eigen3", + "@double_conversion//:double-conversion", + "@nsync//:nsync_cpp", + "@zlib_archive//:zlib", + ], + alwayslink = 1, +) + # Native library support for iOS applications. # # bazel build --config=ios_x86_64 \ @@ -1829,27 +1931,6 @@ cc_library( alwayslink = 1, ) -# Android library for use with the SELECTIVE_REGISTRATION feature. -# Does not contain operators. In contrast to android_tensorflow_lib_lite, -# this links in framework support for all types, relying on selective -# registration of ops to prune code size. -# -# TODO(gonnet): Move all users of these aliases to the corresponding -# :android_tensorflow_lib_lite* targets and remove. -alias( - name = "android_tensorflow_lib_selective_registration", - actual = ":android_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - -# Android library for use with the SELECTIVE_REGISTRATION feature with -# no proto_rtti. -alias( - name = "android_tensorflow_lib_selective_registration_nortti", - actual = ":android_tensorflow_lib_lite_nortti", - visibility = ["//visibility:public"], -) - filegroup( name = "android_op_registrations_and_gradients", srcs = glob( @@ -1862,6 +1943,7 @@ filegroup( "**/*testutil*", "**/*testlib*", "**/*main.cc", + "**/tpu_*", ], ), visibility = ["//visibility:public"], @@ -1964,6 +2046,14 @@ cc_library( ], ) +cc_library( + name = "rocm", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform/default/build_config:rocm", + ], +) + # ----------------------------------------------------------------------------- # Clif-related proto libraries. 
@@ -2023,6 +2113,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/step_stats_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/step_stats.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "framework/types_pyclif", proto_lib = ":protos_all_cc", @@ -2200,6 +2297,7 @@ cc_library( ], }), deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "//third_party/eigen3", "@com_google_absl//absl/base:core_headers", @@ -2214,7 +2312,6 @@ cc_library( "lib/**/*.cc", "platform/*.cc", "platform/profile_utils/**/*.cc", - ] + [ "framework/resource_handle.cc", "util/env_var.cc", ], @@ -2232,6 +2329,7 @@ cc_library( "platform/**/logging.cc", "platform/**/human_readable_json.cc", "platform/abi.cc", + "platform/protobuf.cc", ], ) + tf_additional_lib_srcs( exclude = [ @@ -2258,6 +2356,8 @@ cc_library( ":lib_proto_parsing", ":abi", ":core_stringpiece", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//third_party/eigen3", "//tensorflow/core/platform/default/build_config:platformlib", "@snappy", @@ -2354,7 +2454,12 @@ cc_library( cc_library( name = "tflite_portable_logging", - srcs = [], + srcs = [ + ] + if_ios([ + "platform/default/logging.cc", + "platform/env_time.cc", + "platform/posix/env_time.cc", + ]), hdrs = [ "lib/bfloat16/bfloat16.h", "platform/default/integral_types.h", @@ -2363,7 +2468,7 @@ cc_library( "platform/macros.h", "platform/platform.h", "platform/types.h", - ] + if_windows(["platform/windows/integral_types.h"]), + ] + if_windows(["platform/windows/integral_types.h"]) + if_ios(["platform/env_time.h"]), copts = tf_copts(), linkopts = ["-ldl"], deps = [ @@ -2632,7 +2737,6 @@ tf_cuda_library( "example/**/*.cc", "framework/**/*.cc", "util/**/*.cc", - ] + [ "graph/edgeset.cc", "graph/graph.cc", "graph/graph_def_builder.cc", @@ -2773,6 +2877,7 @@ cc_library( # in this 
library. GRAPH_HDRS = [ "graph/algorithm.h", + "graph/collective_order.h", "graph/colors.h", "graph/control_flow.h", "graph/costmodel.h", @@ -2799,6 +2904,7 @@ tf_cuda_library( name = "graph", srcs = [ "graph/algorithm.cc", + "graph/collective_order.cc", "graph/colors.cc", "graph/control_flow.cc", "graph/costmodel.cc", @@ -2816,6 +2922,9 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) @@ -2830,12 +2939,16 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [ "framework/versions.h", "common_runtime/process_function_library_runtime.h", "common_runtime/function.h", + "common_runtime/scoped_allocator.h", + "common_runtime/scoped_allocator_mgr.h", ] tf_cuda_library( name = "core_cpu_base", srcs = [ "common_runtime/eval_const_tensor.cc", + "common_runtime/scoped_allocator.cc", + "common_runtime/scoped_allocator_mgr.cc", "common_runtime/shape_refiner.cc", "common_runtime/shape_refiner.h", "framework/versions.h", @@ -2868,6 +2981,7 @@ tf_cuda_library( CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/allocator_retry.h", + "common_runtime/shared_counter.h", "common_runtime/base_collective_executor.h", "common_runtime/bfc_allocator.h", "common_runtime/hierarchical_tree_broadcaster.h", @@ -2892,9 +3006,11 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/lower_if_while.h", "common_runtime/lower_while_op.h", "common_runtime/memory_types.h", + "common_runtime/metrics.h", "common_runtime/mkl_cpu_allocator.h", "common_runtime/optimization_registry.h", "common_runtime/pending_counts.h", + "common_runtime/partitioning_utils.h", "common_runtime/placer.h", "common_runtime/process_util.h", "common_runtime/profile_handler.h", @@ -2902,8 +3018,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", "common_runtime/ring_reducer.h", - 
"common_runtime/scoped_allocator.h", - "common_runtime/scoped_allocator_mgr.h", + "common_runtime/ring_alg.h", + "common_runtime/ring_gatherer.h", "common_runtime/session_factory.h", "common_runtime/single_threaded_cpu_device.h", "common_runtime/stats_publisher_interface.h", @@ -2928,6 +3044,8 @@ tf_cuda_library( "common_runtime/collective_param_resolver_local.cc", "common_runtime/collective_rma_local.cc", "common_runtime/collective_util.cc", + "common_runtime/colocation_graph.cc", + "common_runtime/colocation_graph.h", "common_runtime/constant_folding.cc", "common_runtime/copy_tensor.cc", "common_runtime/costmodel_manager.cc", @@ -2948,9 +3066,11 @@ tf_cuda_library( "common_runtime/lower_if_while.cc", "common_runtime/lower_while_op.cc", "common_runtime/memory_types.cc", + "common_runtime/metrics.cc", "common_runtime/mkl_cpu_allocator.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", + "common_runtime/partitioning_utils.cc", "common_runtime/placer.cc", "common_runtime/pool_allocator.cc", "common_runtime/process_function_library_runtime.cc", @@ -2959,9 +3079,9 @@ tf_cuda_library( "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", + "common_runtime/ring_alg.cc", + "common_runtime/ring_gatherer.cc", "common_runtime/ring_reducer.cc", - "common_runtime/scoped_allocator.cc", - "common_runtime/scoped_allocator_mgr.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", "common_runtime/session_options.cc", @@ -2989,8 +3109,9 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//third_party/eigen3", - "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/utils:functions", ] + mkl_deps(), alwayslink = 1, ) @@ -3048,15 +3169,6 @@ cc_library( deps = [":lib_internal"], ) -tf_cuda_library( - name = "metrics", - srcs = ["common_runtime/metrics.cc"], - hdrs = 
["common_runtime/metrics.h"], - deps = [ - ":lib", - ], -) - tf_cuda_library( name = "direct_session_internal", srcs = ["common_runtime/direct_session.cc"], @@ -3073,7 +3185,6 @@ tf_cuda_library( ":graph", ":lib", ":lib_internal", - ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", @@ -3440,6 +3551,7 @@ tf_cc_tests( "platform/vmodule_benchmark_test.cc", ], deps = [ + ":core_cpu_internal", ":lib", ":lib_internal", ":lib_test_internal", @@ -3505,6 +3617,29 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_fake_python_env_test", + size = "small", + srcs = ["platform/fake_python_env_test.cc"], + args = [ + "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py", + ], + tags = [ + "local", + "no_windows", + "nogpu", + "nomac", + "notap", + ], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":test", + ":test_main", + ], +) + tf_cc_test( name = "platform_abi_test", size = "small", @@ -3626,6 +3761,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "lib_strings_proto_serialization_test", + srcs = ["lib/strings/proto_serialization_test.cc"], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":protos_all_cc", + ":test", + ":test_main", + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "lib_random_weighted_picker_test", size = "medium", @@ -3678,7 +3827,6 @@ tf_cc_tests( srcs = [ "common_runtime/buf_rendezvous_test.cc", "common_runtime/collective_executor_mgr_test.cc", - "common_runtime/collective_param_resolver_local_test.cc", "common_runtime/collective_rma_local_test.cc", "common_runtime/device_resolver_local_test.cc", "common_runtime/device_set_test.cc", @@ -3794,6 +3942,7 @@ tf_cc_tests( name = "higher_level_tests_needing_kernels", size = "small", srcs = [ + "common_runtime/collective_param_resolver_local_test.cc", "graph/graph_constructor_test.cc", ], linkopts = select({ @@ -3833,7 +3982,6 @@ tf_cc_test( "ops/cudnn_rnn_ops_test.cc", ], deps = [ - ":cudnn_rnn_ops", 
"//tensorflow/core", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -3843,6 +3991,27 @@ tf_cc_test( ], ) +tf_cc_tests( + name = "collective_order_test", + size = "small", + srcs = [ + "graph/collective_order_test.cc", + ], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + "@com_google_googletest//:gtest_main", + ], +) + tf_cc_tests_gpu( name = "ring_reducer_test", size = "medium", @@ -3872,6 +4041,35 @@ tf_cc_tests_gpu( ], ) +tf_cc_tests_gpu( + name = "ring_gatherer_test", + size = "medium", + srcs = [ + "common_runtime/ring_gatherer_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":all_kernels", + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":protos_test_cc", + ":test", + ":test_main", + ":testlib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_tests_gpu( name = "hierarchical_tree_broadcaster_test", size = "medium", @@ -4059,20 +4257,6 @@ tf_cuda_cc_test( ], ) -tf_cc_test_gpu( - name = "cuda_libdevice_path_test", - size = "small", - srcs = ["platform/cuda_libdevice_path_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":cuda_libdevice_path", - ":lib", - ":test", - ":test_main", - ], -) - tf_cuda_only_cc_test( name = "util_cuda_kernel_helper_test", srcs = [ @@ -4206,7 +4390,7 @@ tf_cc_test( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "common_runtime_process_function_library_runtime_test", size = "small", srcs = ["common_runtime/process_function_library_runtime_test.cc"], @@ -4215,6 +4399,7 @@ tf_cc_test( ":core_cpu", ":core_cpu_internal", ":framework", + ":framework_internal", ":lib", ":test", ":test_main", @@ -4223,6 +4408,7 @@ tf_cc_test( 
"//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:resource_variable_ops", ], ) @@ -4264,6 +4450,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "common_runtime_partitioning_utils_test", + size = "small", + srcs = ["common_runtime/partitioning_utils_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":lib", + ":ops", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_op", + ], +) + tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", @@ -4401,7 +4608,7 @@ tf_cc_test( "//tensorflow/cc:scope", "//tensorflow/core/kernels:cwise_op", "//third_party/eigen3", - ], + ] + if_mkl([":mkl_array_ops_op_lib"]), ) tf_cc_test( @@ -4928,7 +5135,7 @@ filegroup( cc_library( name = "cuda_libdevice_path", - srcs = ["platform/cuda_libdevice_path.cc"] + tf_additional_libdevice_srcs(), + srcs = tf_additional_libdevice_srcs(), hdrs = ["platform/cuda_libdevice_path.h"], copts = tf_copts(), data = tf_additional_libdevice_data(), @@ -4954,6 +5161,39 @@ transitive_hdrs( # ----------------------------------------------------------------------------- # Google-internal targets go here (must be at the end). +load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library") + +genrule( + name = "emscripten_proto_config_lite_runtime", + outs = ["emscripten_proto_config_lite_runtime.asciipb"], + cmd = tf_genrule_cmd_append_to_srcs("optimize_mode:LITE_RUNTIME"), + visibility = ["//visibility:private"], +) + +# We are keeping the "android" version of tf_android_core_proto_headers. All it does is +# normalize CORE_PROTO_SRCS to generate valid output file names. 
+tf_portable_proto_library( + name = "emscripten_proto_lib_no_rtti_lite_runtime", + config = ":emscripten_proto_config_lite_runtime", + copts = tf_opts_nortti_if_emscripten(), + features = tf_features_nomodules_if_emscripten(), + header_outs = tf_android_core_proto_headers(CORE_PROTO_SRCS) + ["//google/protobuf/any.proto.h"], + link_full_protobuf = False, + prefix_dir = "emscripten_proto_no_rtti", + proto_deps = [ + ":protos_all_cc", + "@protobuf_archive//:protobuf", + ], + visibility = ["//visibility:public"], +) + +# There is currently no need for a full proto version of emscripten tf lib lite. +alias( + name = "emscripten_lib_lite_no_runtime", + actual = "//tensorflow/core:emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", + visibility = ["//visibility:public"], +) + alias( name = "android_srcs_no_runtime", actual = ":mobile_srcs_no_runtime", diff --git a/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d6f28bd022bcd843aa3a7aeb8b1b257a3b3ddfd3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt @@ -0,0 +1,67 @@ +op { + graph_op_name: "AllToAll" + in_arg { + name: "input" + description: <